diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..0248cc6 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,4 @@ +generated_types/protos/google/ linguist-generated=true +generated_types/protos/grpc/ linguist-generated=true +generated_types/src/wal_generated.rs linguist-generated=true +trace_exporters/src/thrift/ linguist-generated=true diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8e56943 --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +**/target +**/*.rs.bk +.idea/ +.env +.gdb_history +*.tsm +**/.DS_Store +**/.vscode +heaptrack.* +massif.out.* +perf.data* +perf.svg +perf.txt +valgrind-out.txt +*.pending-snap diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..fd4a283 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,7141 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "addr2line" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "ahash" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" +dependencies = [ + "cfg-if", + "const-random", + "getrandom", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "aho-corasick" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" +dependencies = [ + "memchr", +] + +[[package]] +name = "aliasable" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd" + +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + +[[package]] +name = "allocator-api2" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" + +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstream" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e2e1ebcb11de5c03c67de28a7df593d32191b44939c482e97702baaaa6ab6a5" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" + +[[package]] +name = "anstyle-parse" +version = "0.2.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "c75ac65da39e5fe5ab759307499ddad880d724eed2f6ce5b5e8a26f4f387928c" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e28923312444cdd728e4738b3f9c9cac739500909bb3d3c94b43551b16517648" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cd54b81ec8d6180e24654d0b371ad22fc3dd083b6ff8ba325b72e00c87660a7" +dependencies = [ + "anstyle", + "windows-sys 0.52.0", +] + +[[package]] +name = "anyhow" +version = "1.0.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" + +[[package]] +name = "arc-swap" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6" + +[[package]] +name = "arrayref" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b4930d2cb77ce62f89ee5d5289b4ac049559b1c45539271f5ed4fdc7db34545" + +[[package]] +name = "arrayvec" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" + +[[package]] +name = "arrow" +version = "49.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5bc25126d18a012146a888a0298f2c22e1150327bd2765fc76d710a556b2d614" +dependencies = [ + "ahash", + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-csv", + "arrow-data", + "arrow-ipc", + "arrow-json", + "arrow-ord", + "arrow-row", + "arrow-schema", + "arrow-select", + "arrow-string", +] + +[[package]] +name = "arrow-arith" +version = "49.0.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "34ccd45e217ffa6e53bbb0080990e77113bdd4e91ddb84e97b77649810bcf1a7" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "chrono", + "half", + "num", +] + +[[package]] +name = "arrow-array" +version = "49.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bda9acea48b25123c08340f3a8ac361aa0f74469bb36f5ee9acf923fce23e9d" +dependencies = [ + "ahash", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "chrono", + "chrono-tz", + "half", + "hashbrown 0.14.3", + "num", +] + +[[package]] +name = "arrow-buffer" +version = "49.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01a0fc21915b00fc6c2667b069c1b64bdd920982f426079bc4a7cab86822886c" +dependencies = [ + "bytes", + "half", + "num", +] + +[[package]] +name = "arrow-cast" +version = "49.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dc0368ed618d509636c1e3cc20db1281148190a78f43519487b2daf07b63b4a" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "base64", + "chrono", + "comfy-table", + "half", + "lexical-core", + "num", +] + +[[package]] +name = "arrow-csv" +version = "49.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e09aa6246a1d6459b3f14baeaa49606cfdbca34435c46320e14054d244987ca" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "chrono", + "csv", + "csv-core", + "lazy_static", + "lexical-core", + "regex", +] + +[[package]] +name = "arrow-data" +version = "49.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "907fafe280a3874474678c1858b9ca4cb7fd83fb8034ff5b6d6376205a08c634" +dependencies = [ + "arrow-buffer", + "arrow-schema", + "half", + "num", +] + +[[package]] +name = "arrow-flight" +version = "49.0.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "624e0dcb6b5a7a06222bfd2be3f7e905ce849a6b714ec989f18cdba330c77d38" +dependencies = [ + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ipc", + "arrow-ord", + "arrow-row", + "arrow-schema", + "arrow-select", + "arrow-string", + "base64", + "bytes", + "futures", + "once_cell", + "paste", + "prost", + "tokio", + "tonic", +] + +[[package]] +name = "arrow-ipc" +version = "49.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79a43d6808411886b8c7d4f6f7dd477029c1e77ffffffb7923555cc6579639cd" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "flatbuffers", + "lz4_flex", +] + +[[package]] +name = "arrow-json" +version = "49.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d82565c91fd627922ebfe2810ee4e8346841b6f9361b87505a9acea38b614fee" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "chrono", + "half", + "indexmap 2.2.2", + "lexical-core", + "num", + "serde", + "serde_json", +] + +[[package]] +name = "arrow-ord" +version = "49.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b23b0e53c0db57c6749997fd343d4c0354c994be7eca67152dd2bdb9a3e1bb4" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "half", + "num", +] + +[[package]] +name = "arrow-row" +version = "49.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "361249898d2d6d4a6eeb7484be6ac74977e48da12a4dd81a708d620cc558117a" +dependencies = [ + "ahash", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "half", + "hashbrown 0.14.3", +] + +[[package]] +name = "arrow-schema" +version = "49.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"09e28a5e781bf1b0f981333684ad13f5901f4cd2f20589eab7cf1797da8fc167" + +[[package]] +name = "arrow-select" +version = "49.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f6208466590960efc1d2a7172bc4ff18a67d6e25c529381d7f96ddaf0dc4036" +dependencies = [ + "ahash", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "num", +] + +[[package]] +name = "arrow-string" +version = "49.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a48149c63c11c9ff571e50ab8f017d2a7cb71037a882b42f6354ed2da9acc7" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "num", + "regex", + "regex-syntax 0.8.2", +] + +[[package]] +name = "arrow_util" +version = "0.1.0" +dependencies = [ + "ahash", + "arrow", + "chrono", + "comfy-table", + "datafusion", + "hashbrown 0.14.3", + "num-traits", + "once_cell", + "proptest", + "rand", + "regex", + "snafu 0.8.0", + "uuid", + "workspace-hack", +] + +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "assert_cmd" +version = "2.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00ad3f3a942eee60335ab4342358c161ee296829e0d16ff42fc1d6cb07815467" +dependencies = [ + "anstyle", + "bstr", + "doc-comment", + "predicates", + "predicates-core", + "predicates-tree", + "wait-timeout", +] + +[[package]] +name = "assert_matches" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9" + +[[package]] +name = "async-channel" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1ca33f4bc4ed1babef42cad36cc1f51fa88be00420404e5b1e80ab1b18f7678c" +dependencies = [ + "concurrent-queue", + "event-listener 4.0.3", + "event-listener-strategy", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-compression" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a116f46a969224200a0a97f29cfd4c50e7534e4b4826bd23ea2c3c533039c82c" +dependencies = [ + "bzip2", + "flate2", + "futures-core", + "futures-io", + "memchr", + "pin-project-lite", + "tokio", + "xz2", + "zstd", + "zstd-safe", +] + +[[package]] +name = "async-lock" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "287272293e9d8c41773cec55e365490fe034813a2f172f502d6ddcf75b2f582b" +dependencies = [ + "event-listener 2.5.3", +] + +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "async-trait" +version = "0.1.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "atoi" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +dependencies = [ + "num-traits", +] + +[[package]] +name = "atomic-write-file" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "edcdbedc2236483ab103a53415653d6b4442ea6141baf1ffa85df29635e88436" +dependencies = [ + "nix 0.27.1", + "rand", +] + +[[package]] +name = "authz" +version = "0.1.0" +dependencies = [ + "assert_matches", + "async-trait", + "backoff 0.1.0", + "base64", + "generated_types", + "http", + "iox_time", + "metric", + "observability_deps", + "parking_lot", + "paste", + "snafu 0.8.0", + "test_helpers_end_to_end", + "tokio", + "tonic", + "workspace-hack", +] + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "axum" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" +dependencies = [ + "async-trait", + "axum-core", + "bitflags 1.3.2", + "bytes", + "futures-util", + "http", + "http-body", + "hyper", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "mime", + "rustversion", + "tower-layer", + "tower-service", +] + +[[package]] +name = "backoff" +version = "0.1.0" +dependencies = [ + "observability_deps", + "rand", + "snafu 0.8.0", + "tokio", + "workspace-hack", +] + +[[package]] +name = "backoff" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" +dependencies = [ + "getrandom", + "instant", + "rand", +] + +[[package]] +name = "backtrace" +version = "0.3.69" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" +dependencies = [ + "serde", +] + +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + +[[package]] +name = "blake3" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0231f06152bf547e9c2b5194f247cd97aacf6dcd8b15d8e5ec0663f64580da87" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", +] + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "brotli" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"516074a47ef4bce09577a3b379392300159ce5b1ba2e501ff1c819950066100f" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "2.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e2e4afe60d7dd600fdd3de8d0f08c2b7ec039712e3b6137ff98b7004e82de4f" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + +[[package]] +name = "bstr" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c48f0051a4b4c5e0b6d365cd04af53aeaa209e3cc15ec2cdb69e73cc87fbd0dc" +dependencies = [ + "memchr", + "regex-automata 0.4.5", + "serde", +] + +[[package]] +name = "bumpalo" +version = "3.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" + +[[package]] +name = "bytecount" +version = "0.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1e5f035d16fc623ae5f74981db80a439803888314e3a555fd6f04acd51a3205" + +[[package]] +name = "bytemuck" +version = "1.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea31d69bda4949c1c1562c1e6f042a1caefac98cdc8a298260a2ff41c1e2d42b" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "bytes" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" + +[[package]] +name = "bzip2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + +[[package]] +name = "cache_system" +version = "0.1.0" +dependencies = [ + "async-trait", + "backoff 0.1.0", + "criterion", + "futures", + "iox_time", + "metric", + "observability_deps", + "ouroboros", + "parking_lot", + "pdatastructs", + "proptest", + "rand", + "test_helpers", + "tokio", + "tokio-util", + "trace", + "tracker", + "workspace-hack", +] + +[[package]] +name = "camino" +version = "1.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c59e92b5a388f549b863a7bea62612c09f24c8393560709a54558a9abdfb3b9c" +dependencies = [ + "serde", +] + +[[package]] +name = "cargo-platform" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ceed8ef69d8518a5dda55c07425450b58a4e1946f4951eab6d7191ee86c2443d" +dependencies = [ + "serde", +] + +[[package]] +name = "cargo_metadata" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4acbb09d9ee8e23699b9634375c72795d095bf268439da88562cf9b501f181fa" +dependencies = [ + "camino", + "cargo-platform", + "semver", + "serde", + "serde_json", +] + +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "catalog_cache" +version = "0.1.0" +dependencies = [ + "bytes", + "dashmap", + "futures", + "hyper", + "reqwest", + "snafu 0.8.0", + "tokio", + "tokio-util", + "url", + "workspace-hack", +] + +[[package]] +name = "cc" +version = "1.0.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +dependencies = [ + "jobserver", + "libc", +] + +[[package]] +name = "cfg-if" +version = 
"1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "chrono" +version = "0.4.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f13690e35a5e4ace198e7beea2895d29f3a9cc55015fcebe6336bd2010af9eb" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-targets 0.52.0", +] + +[[package]] +name = "chrono-tz" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91d7b79e99bfaa0d47da0687c43aa3b7381938a62ad3a6498599039321f660b7" +dependencies = [ + "chrono", + "chrono-tz-build", + "phf", +] + +[[package]] +name = "chrono-tz-build" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "433e39f13c9a060046954e0592a8d0a4bcb1040125cbf91cb8ee58964cfb350f" +dependencies = [ + "parse-zoneinfo", + "phf", + "phf_codegen", +] + +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + +[[package]] +name = "clap" +version = "4.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e578d6ec4194633722ccf9544794b71b1385c3c027efe0c55db226fc880865c" +dependencies = [ + "clap_builder", + "clap_derive", +] + 
+[[package]] +name = "clap_blocks" +version = "0.1.0" +dependencies = [ + "clap", + "ed25519-dalek", + "futures", + "http", + "humantime", + "iox_catalog", + "iox_time", + "itertools 0.12.1", + "metric", + "non-empty-string", + "object_store", + "observability_deps", + "parquet_cache", + "snafu 0.8.0", + "sysinfo", + "tempfile", + "test_helpers", + "trace_exporters", + "trogging", + "url", + "uuid", + "workspace-hack", +] + +[[package]] +name = "clap_builder" +version = "4.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4df4df40ec50c46000231c914968278b1eb05098cf8f1b3a518a95030e71d1c7" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf9804afaaf59a91e75b022a30fb7229a7901f60c755489cc61c9b423b836442" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "clap_lex" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" + +[[package]] +name = "client_util" +version = "0.1.0" +dependencies = [ + "http", + "mockito", + "reqwest", + "thiserror", + "tokio", + "tonic", + "tower", + "workspace-hack", +] + +[[package]] +name = "colorchoice" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" + +[[package]] +name = "comfy-table" +version = "7.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c64043d6c7b7a4c58e39e7efccfdea7b93d885a795d0c054a69dbbf4dd52686" +dependencies = [ + "strum", + "strum_macros", + "unicode-width", +] + +[[package]] +name = "concurrent-queue" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d16048cd947b08fa32c24458a22f5dc5e835264f689f4f5653210c69fd107363" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "console" +version = "0.15.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + +[[package]] +name = "const-random" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aaf16c9c2c612020bcfd042e170f6e32de9b9d75adb5277cdbbd2e2c8c8299a" +dependencies = [ + "const-random-macro", +] + +[[package]] +name = "const-random-macro" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" +dependencies = [ + "getrandom", + "once_cell", + "tiny-keccak", +] + +[[package]] +name = "constant_time_eq" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "cpp_demangle" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7e8227005286ec39567949b33df9896bcadfa6051bccca2488129f108ca23119" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "cpufeatures" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +dependencies = [ + "libc", +] + +[[package]] +name = "crc" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86ec7a15cbe22e59248fc7eadb1907dab5ba09372595da4d73dd805ed4417dfe" +dependencies = [ + "crc-catalog", +] + +[[package]] +name = "crc-catalog" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" + +[[package]] +name = "crc32fast" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "futures", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "tokio", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + +[[package]] +name = "croaring" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7266f0a7275b00ce4c4f4753e8c31afdefe93828101ece83a06e2ddab1dd1010" +dependencies = [ + "byteorder", + "croaring-sys", +] + 
+[[package]] +name = "croaring-sys" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e47112498c394a7067949ebc07ef429b7384a413cf0efcf675846a47bcd307fb" +dependencies = [ + "cc", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "176dc175b78f56c0f321911d9c8eb2b77a78a4860b9c19db83835fea1a46649b" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-queue" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "csv" +version = "1.3.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + +[[package]] +name = "curve25519-dalek" +version = "4.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a677b8922c94e01bdbb12126b0bc852f00447528dee1782229af9c720c3f348" +dependencies = [ + "cfg-if", + "cpufeatures", + "curve25519-dalek-derive", + "digest", + "fiat-crypto", + "platforms", + "rustc_version", + "subtle", + "zeroize", +] + +[[package]] +name = "curve25519-dalek-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "darling" +version = "0.20.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc5d6b04b3fd0ba9926f945895de7d806260a2d7431ba82e7edaecb043c4c6b8" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04e48a959bcd5c761246f5d090ebc2fbf7b9cd527a492b07a67510c108f1e7e3" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.48", +] + +[[package]] +name = "darling_macro" +version = "0.20.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1545d67a2149e1d93b7e5c7752dce5a7426eb5d1357ddcfd89336b94444f77" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "dashmap" +version = "5.5.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.3", + "lock_api", + "once_cell", + "parking_lot_core", +] + +[[package]] +name = "data_types" +version = "0.1.0" +dependencies = [ + "arrow-buffer", + "assert_matches", + "bytes", + "chrono", + "croaring", + "generated_types", + "hex", + "influxdb-line-protocol", + "iox_time", + "murmur3", + "observability_deps", + "once_cell", + "ordered-float 4.2.0", + "paste", + "percent-encoding", + "proptest", + "prost", + "schema", + "serde_json", + "sha2", + "siphasher 1.0.0", + "snafu 0.8.0", + "sqlx", + "test_helpers", + "thiserror", + "uuid", + "workspace-hack", +] + +[[package]] +name = "datafusion" +version = "34.0.0" +source = "git+https://github.com/apache/arrow-datafusion.git?rev=0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7#0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7" +dependencies = [ + "ahash", + "arrow", + "arrow-array", + "arrow-ipc", + "arrow-schema", + "async-compression", + "async-trait", + "bytes", + "bzip2", + "chrono", + "dashmap", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-optimizer", + "datafusion-physical-expr", + "datafusion-physical-plan", + "datafusion-sql", + "flate2", + "futures", + "glob", + "half", + "hashbrown 0.14.3", + "indexmap 2.2.2", + "itertools 0.12.1", + "log", + "num_cpus", + "object_store", + "parking_lot", + "parquet", + "pin-project-lite", + "rand", + "sqlparser", + "tempfile", + "tokio", + "tokio-util", + "url", + "uuid", + "xz2", + "zstd", +] + +[[package]] +name = "datafusion-common" +version = "34.0.0" +source = "git+https://github.com/apache/arrow-datafusion.git?rev=0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7#0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7" +dependencies = [ + "ahash", + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-schema", + "chrono", + "half", + "libc", + "num_cpus", + "object_store", + "parquet", + 
"sqlparser", +] + +[[package]] +name = "datafusion-execution" +version = "34.0.0" +source = "git+https://github.com/apache/arrow-datafusion.git?rev=0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7#0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7" +dependencies = [ + "arrow", + "chrono", + "dashmap", + "datafusion-common", + "datafusion-expr", + "futures", + "hashbrown 0.14.3", + "log", + "object_store", + "parking_lot", + "rand", + "tempfile", + "url", +] + +[[package]] +name = "datafusion-expr" +version = "34.0.0" +source = "git+https://github.com/apache/arrow-datafusion.git?rev=0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7#0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7" +dependencies = [ + "ahash", + "arrow", + "arrow-array", + "datafusion-common", + "paste", + "sqlparser", + "strum", + "strum_macros", +] + +[[package]] +name = "datafusion-optimizer" +version = "34.0.0" +source = "git+https://github.com/apache/arrow-datafusion.git?rev=0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7#0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7" +dependencies = [ + "arrow", + "async-trait", + "chrono", + "datafusion-common", + "datafusion-expr", + "datafusion-physical-expr", + "hashbrown 0.14.3", + "itertools 0.12.1", + "log", + "regex-syntax 0.8.2", +] + +[[package]] +name = "datafusion-physical-expr" +version = "34.0.0" +source = "git+https://github.com/apache/arrow-datafusion.git?rev=0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7#0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7" +dependencies = [ + "ahash", + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-ord", + "arrow-schema", + "base64", + "blake2", + "blake3", + "chrono", + "datafusion-common", + "datafusion-expr", + "half", + "hashbrown 0.14.3", + "hex", + "indexmap 2.2.2", + "itertools 0.12.1", + "log", + "md-5", + "paste", + "petgraph", + "rand", + "regex", + "sha2", + "unicode-segmentation", + "uuid", +] + +[[package]] +name = "datafusion-physical-plan" +version = "34.0.0" +source = 
"git+https://github.com/apache/arrow-datafusion.git?rev=0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7#0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7" +dependencies = [ + "ahash", + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-schema", + "async-trait", + "chrono", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-expr", + "futures", + "half", + "hashbrown 0.14.3", + "indexmap 2.2.2", + "itertools 0.12.1", + "log", + "once_cell", + "parking_lot", + "pin-project-lite", + "rand", + "tokio", + "uuid", +] + +[[package]] +name = "datafusion-proto" +version = "34.0.0" +source = "git+https://github.com/apache/arrow-datafusion.git?rev=0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7#0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7" +dependencies = [ + "arrow", + "chrono", + "datafusion", + "datafusion-common", + "datafusion-expr", + "object_store", + "prost", +] + +[[package]] +name = "datafusion-sql" +version = "34.0.0" +source = "git+https://github.com/apache/arrow-datafusion.git?rev=0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7#0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7" +dependencies = [ + "arrow", + "arrow-schema", + "datafusion-common", + "datafusion-expr", + "log", + "sqlparser", +] + +[[package]] +name = "datafusion_util" +version = "0.1.0" +dependencies = [ + "async-trait", + "datafusion", + "futures", + "object_store", + "observability_deps", + "pin-project", + "schema", + "tokio", + "tokio-stream", + "url", + "workspace-hack", +] + +[[package]] +name = "debugid" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef552e6f588e446098f6ba40d89ac146c8c7b64aade83c051ee00bb5d2bc18d" +dependencies = [ + "uuid", +] + +[[package]] +name = "delegate" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "082a24a9967533dc5d743c602157637116fc1b52806d694a5a45e6f32567fcdd" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "der" 
+version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c" +dependencies = [ + "const-oid", + "pem-rfc7468", + "zeroize", +] + +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "diff" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" + +[[package]] +name = "difflib" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8" + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "const-oid", + "crypto-common", + "subtle", +] + +[[package]] +name = "dml" +version = "0.1.0" +dependencies = [ + "arrow_util", + "data_types", + "hashbrown 0.14.3", + "iox_time", + "mutable_batch", + "schema", + "trace", + "workspace-hack", +] + +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + +[[package]] +name = "dotenvy" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" + +[[package]] +name = "dyn-clone" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "545b22097d44f8a9581187cdf93de7a71e4722bf51200cfaba810865b49a495d" + +[[package]] +name = 
"ed25519" +version = "2.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53" +dependencies = [ + "pkcs8", + "signature", +] + +[[package]] +name = "ed25519-dalek" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a3daa8e81a3963a60642bcc1f90a670680bd4a77535faa384e9d1c79d620871" +dependencies = [ + "curve25519-dalek", + "ed25519", + "serde", + "sha2", + "subtle", + "zeroize", +] + +[[package]] +name = "either" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +dependencies = [ + "serde", +] + +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + +[[package]] +name = "encoding_rs" +version = "0.8.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7268b386296a025e474d5140678f75d6de9493ae55a5d709eeb9dd08149945e1" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + +[[package]] +name = "errno" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "error-chain" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d2f06b9cac1506ece98fe3231e3cc9c4410ec3d5b1f24ae1c8946f0742cdefc" +dependencies = [ + "version_check", +] + +[[package]] +name = "etcetera" +version = "0.8.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" +dependencies = [ + "cfg-if", + "home", + "windows-sys 0.48.0", +] + +[[package]] +name = "event-listener" +version = "2.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" + +[[package]] +name = "event-listener" +version = "4.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b215c49b2b248c855fb73579eb1f4f26c38ffdc12973e20e07b91d78d5646e" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3" +dependencies = [ + "event-listener 4.0.3", + "pin-project-lite", +] + +[[package]] +name = "executor" +version = "0.1.0" +dependencies = [ + "futures", + "libc", + "metric", + "observability_deps", + "once_cell", + "parking_lot", + "pin-project", + "snafu 0.8.0", + "tokio", + "tokio-util", + "tokio_metrics_bridge", + "tokio_watchdog", + "workspace-hack", +] + +[[package]] +name = "fastrand" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" + +[[package]] +name = "fiat-crypto" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1676f435fc1dadde4d03e43f5d62b259e1ce5f40bd4ffb21db2b42ebe59c1382" + +[[package]] +name = "filetime" +version = "0.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ee447700ac8aa0b2f2bd7bc4462ad686ba06baa6727ac149a2d6277f0d240fd" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "windows-sys 0.52.0", +] + +[[package]] +name = "findshlibs" +version = 
"0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40b9e59cd0f7e0806cca4be089683ecb6434e602038df21fe6bf6711b2f07f64" +dependencies = [ + "cc", + "lazy_static", + "libc", + "winapi", +] + +[[package]] +name = "finl_unicode" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fcfdc7a0362c9f4444381a9e697c79d435fe65b52a37466fc2c1184cee9edc6" + +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + +[[package]] +name = "flatbuffers" +version = "23.5.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dac53e22462d78c16d64a1cd22371b54cc3fe94aa15e7886a2fa6e5d1ab8640" +dependencies = [ + "bitflags 1.3.2", + "rustc_version", +] + +[[package]] +name = "flate2" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "flightsql" +version = "0.1.0" +dependencies = [ + "arrow", + "arrow-flight", + "arrow_util", + "bytes", + "datafusion", + "iox_query", + "observability_deps", + "once_cell", + "prost", + "snafu 0.8.0", + "workspace-hack", +] + +[[package]] +name = "flume" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" +dependencies = [ + "futures-core", + "futures-sink", + "spin 0.9.8", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "fsevent-sys" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76ee7a02da4d231650c7cea31349b889be2f45ddb3ef3032d2ec8185f6313fd2" +dependencies = [ + "libc", +] + +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-intrusive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" +dependencies = [ + "futures-core", + "lock_api", + "parking_lot", +] + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-macro" +version = "0.3.30" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "futures-sink" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" + +[[package]] +name = "futures-task" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" + +[[package]] +name = "futures-util" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "generated_types" +version = "0.1.0" +dependencies = [ + "bytes", + "observability_deps", + "pbjson", + "pbjson-build", + "pbjson-types", + "prost", + "prost-build", + "serde", + "tonic", + "tonic-build", + "uuid", + "workspace-hack", +] + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "gimli" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" + +[[package]] +name = "glob" 
+version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + +[[package]] +name = "grpc-binary-logger" +version = "0.1.0" +dependencies = [ + "assert_matches", + "base64", + "byteorder", + "bytes", + "futures", + "grpc-binary-logger-proto", + "grpc-binary-logger-test-proto", + "http", + "http-body", + "hyper", + "pin-project", + "prost", + "prost-build", + "tokio", + "tokio-stream", + "tonic", + "tonic-build", + "tower", + "workspace-hack", +] + +[[package]] +name = "grpc-binary-logger-proto" +version = "0.1.0" +dependencies = [ + "prost", + "prost-build", + "prost-types", + "tonic", + "tonic-build", + "workspace-hack", +] + +[[package]] +name = "grpc-binary-logger-test-proto" +version = "0.1.0" +dependencies = [ + "prost", + "prost-build", + "tonic", + "tonic-build", + "workspace-hack", +] + +[[package]] +name = "h2" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb2c4422095b67ee78da96fbb51a4cc413b3b25883c7717ff7ca1ab31022c9c9" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http", + "indexmap 2.2.2", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "half" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" +dependencies = [ + "cfg-if", + "crunchy", + "num-traits", +] + +[[package]] +name = "handlebars" +version = "5.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab283476b99e66691dee3f1640fea91487a8d81f50fb5ecc75538f8f8879a1e4" +dependencies = [ + "log", + "pest", + "pest_derive", + "serde", + "serde_json", + "thiserror", +] + +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + +[[package]] +name = "hashbrown" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +dependencies = [ + "ahash", + "allocator-api2", +] + +[[package]] +name = "hashlink" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" +dependencies = [ + "hashbrown 0.14.3", +] + +[[package]] +name = "heappy" +version = "0.1.0" +source = "git+https://github.com/mkmik/heappy?rev=01a1f88e1b404c5894f89eb1a57f813f713d7ad1#01a1f88e1b404c5894f89eb1a57f813f713d7ad1" +dependencies = [ + "backtrace", + "bytes", + "lazy_static", + "libc", + "pprof", + "spin 0.9.8", + "thiserror", + "tikv-jemalloc-sys", +] + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +dependencies = [ + "unicode-segmentation", +] + +[[package]] +name = "hermit-abi" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0c62115964e08cb8039170eb33c1d0e2388a256930279edca206fff675f82c3" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "hkdf" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" +dependencies = [ + "hmac", +] + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + +[[package]] +name = 
"home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "http" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" +dependencies = [ + "bytes", + "http", + "pin-project-lite", +] + +[[package]] +name = "http-range-header" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f" + +[[package]] +name = "httparse" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + +[[package]] +name = "hyper" +version = "0.14.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", + "want", +] + +[[package]] +name = 
"hyper-rustls" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +dependencies = [ + "futures-util", + "http", + "hyper", + "log", + "rustls", + "rustls-native-certs", + "tokio", + "tokio-rustls", +] + +[[package]] +name = "hyper-timeout" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" +dependencies = [ + "hyper", + "pin-project-lite", + "tokio", + "tokio-io-timeout", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "import_export" +version = "0.1.0" +dependencies = [ + "bytes", + "data_types", + "futures-util", + "generated_types", + "influxdb_iox_client", + "iox_catalog", + "object_store", + "observability_deps", + "parquet_file", + "schema", + "serde_json", + "thiserror", + "tokio", + "tokio-util", + "workspace-hack", +] + 
+[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", +] + +[[package]] +name = "indexmap" +version = "2.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520" +dependencies = [ + "equivalent", + "hashbrown 0.14.3", +] + +[[package]] +name = "inferno" +version = "0.11.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "321f0f839cd44a4686e9504b0a62b4d69a50b62072144c71c68f5873c167b8d9" +dependencies = [ + "ahash", + "indexmap 2.2.2", + "is-terminal", + "itoa", + "log", + "num-format", + "once_cell", + "quick-xml 0.26.0", + "rgb", + "str_stack", +] + +[[package]] +name = "influxdb-line-protocol" +version = "1.0.0" +dependencies = [ + "bytes", + "log", + "nom", + "smallvec", + "snafu 0.8.0", + "test_helpers", +] + +[[package]] +name = "influxdb2_client" +version = "0.1.0" +dependencies = [ + "bytes", + "futures", + "mockito", + "once_cell", + "parking_lot", + "reqwest", + "serde", + "serde_json", + "snafu 0.8.0", + "test_helpers", + "tokio", + "url", + "uuid", +] + +[[package]] +name = "influxdb_influxql_parser" +version = "0.1.0" +dependencies = [ + "assert_matches", + "chrono", + "chrono-tz", + "insta", + "nom", + "num-integer", + "num-traits", + "once_cell", + "paste", + "test_helpers", + "workspace-hack", +] + +[[package]] +name = "influxdb_iox_client" +version = "0.1.0" +dependencies = [ + "arrow", + "arrow-flight", + "arrow_util", + "bytes", + "client_util", + "comfy-table", + "futures-util", + "generated_types", + "influxdb-line-protocol", + "insta", + "iox_query_params", + "prost", + "rand", + "reqwest", + "schema", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", + "tonic", +] + +[[package]] +name = "influxdb_storage_client" +version 
= "0.1.0" +dependencies = [ + "client_util", + "futures-util", + "generated_types", + "observability_deps", + "prost", + "tonic", + "workspace-hack", +] + +[[package]] +name = "influxdb_tsm" +version = "0.1.0" +dependencies = [ + "flate2", + "hex", + "integer-encoding 4.0.0", + "observability_deps", + "rand", + "snafu 0.7.5", + "snap", + "test_helpers", + "workspace-hack", +] + +[[package]] +name = "influxrpc_parser" +version = "0.1.0" +dependencies = [ + "generated_types", + "snafu 0.8.0", + "sqlparser", + "workspace-hack", +] + +[[package]] +name = "ingester_query_grpc" +version = "0.1.0" +dependencies = [ + "arrow", + "base64", + "bytes", + "data_types", + "datafusion", + "datafusion-proto", + "flatbuffers", + "pbjson", + "pbjson-build", + "predicate", + "prost", + "prost-build", + "query_functions", + "serde", + "snafu 0.8.0", + "tonic", + "tonic-build", + "workspace-hack", +] + +[[package]] +name = "inotify" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8069d3ec154eb856955c1c0fbffefbf5f3c40a104ec912d4797314c1801abff" +dependencies = [ + "bitflags 1.3.2", + "inotify-sys", + "libc", +] + +[[package]] +name = "inotify-sys" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" +dependencies = [ + "libc", +] + +[[package]] +name = "insta" +version = "1.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d64600be34b2fcfc267740a243fa7744441bb4947a619ac4e5bb6507f35fbfc" +dependencies = [ + "console", + "lazy_static", + "linked-hash-map", + "serde", + "similar", + "yaml-rust", +] + +[[package]] +name = "instant" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "integer-encoding" +version = "3.0.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" + +[[package]] +name = "integer-encoding" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "924df4f0e24e2e7f9cdd90babb0b96f93b20f3ecfa949ea9e6613756b8c8e1bf" + +[[package]] +name = "iox_catalog" +version = "0.1.0" +dependencies = [ + "assert_matches", + "async-trait", + "backoff 0.1.0", + "catalog_cache", + "data_types", + "dotenvy", + "futures", + "generated_types", + "iox_time", + "log", + "metric", + "mutable_batch", + "mutable_batch_lp", + "observability_deps", + "once_cell", + "parking_lot", + "paste", + "pretty_assertions", + "proptest", + "rand", + "serde", + "siphasher 1.0.0", + "snafu 0.8.0", + "sqlx", + "sqlx-hotswap-pool", + "tempfile", + "test_helpers", + "thiserror", + "tokio", + "tonic", + "trace_http", + "uuid", + "workspace-hack", +] + +[[package]] +name = "iox_data_generator" +version = "0.1.0" +dependencies = [ + "bytes", + "chrono", + "clap", + "criterion", + "datafusion_util", + "futures", + "handlebars", + "humantime", + "influxdb2_client", + "itertools 0.12.1", + "mutable_batch", + "mutable_batch_lp", + "parquet_file", + "rand", + "regex", + "schema", + "serde", + "serde_json", + "snafu 0.8.0", + "test_helpers", + "tokio", + "toml", + "tracing", + "tracing-subscriber", + "uuid", +] + +[[package]] +name = "iox_query" +version = "0.1.0" +dependencies = [ + "arrow", + "arrow_util", + "assert_matches", + "async-trait", + "chrono", + "data_types", + "datafusion", + "datafusion_util", + "executor", + "futures", + "hashbrown 0.14.3", + "indexmap 2.2.2", + "insta", + "iox_time", + "itertools 0.12.1", + "metric", + "object_store", + "observability_deps", + "once_cell", + "parking_lot", + "parquet_file", + "predicate", + "query_functions", + "schema", + "serde", + "snafu 0.8.0", + "test_helpers", + "tokio", + "tokio-stream", + "trace", + "tracker", + "uuid", + 
"workspace-hack", +] + +[[package]] +name = "iox_query_influxql" +version = "0.1.0" +dependencies = [ + "arrow", + "assert_matches", + "chrono", + "chrono-tz", + "datafusion", + "datafusion_util", + "generated_types", + "influxdb_influxql_parser", + "insta", + "iox_query", + "itertools 0.12.1", + "observability_deps", + "once_cell", + "predicate", + "query_functions", + "regex", + "schema", + "serde_json", + "test_helpers", + "thiserror", + "workspace-hack", +] + +[[package]] +name = "iox_query_influxrpc" +version = "0.1.0" +dependencies = [ + "arrow", + "arrow_util", + "data_types", + "datafusion", + "datafusion_util", + "futures", + "hashbrown 0.14.3", + "insta", + "iox_query", + "observability_deps", + "predicate", + "query_functions", + "schema", + "snafu 0.8.0", + "test_helpers", + "tokio", + "workspace-hack", +] + +[[package]] +name = "iox_query_params" +version = "0.1.0" +dependencies = [ + "assert_matches", + "datafusion", + "generated_types", + "observability_deps", + "serde", + "serde_json", + "thiserror", + "workspace-hack", +] + +[[package]] +name = "iox_tests" +version = "0.1.0" +dependencies = [ + "arrow", + "data_types", + "datafusion", + "datafusion_util", + "generated_types", + "iox_catalog", + "iox_query", + "iox_time", + "metric", + "mutable_batch_lp", + "object_store", + "observability_deps", + "parquet_file", + "schema", + "uuid", + "workspace-hack", +] + +[[package]] +name = "iox_time" +version = "0.1.0" +dependencies = [ + "chrono", + "parking_lot", + "tokio", + "workspace-hack", +] + +[[package]] +name = "ioxd_common" +version = "0.1.0" +dependencies = [ + "async-trait", + "authz", + "bytes", + "clap", + "clap_blocks", + "flate2", + "futures", + "generated_types", + "hashbrown 0.14.3", + "heappy", + "http", + "hyper", + "log", + "metric", + "metric_exporters", + "observability_deps", + "parking_lot", + "pprof", + "reqwest", + "serde", + "serde_json", + "serde_urlencoded", + "service_grpc_testing", + "snafu 0.8.0", + "tokio", + 
"tokio-stream", + "tokio-util", + "tonic", + "tonic-health", + "tonic-reflection", + "tower", + "tower-http", + "tower_trailer", + "trace", + "trace_exporters", + "trace_http", + "workspace-hack", +] + +[[package]] +name = "ioxd_test" +version = "0.1.0" +dependencies = [ + "async-trait", + "clap", + "hyper", + "ioxd_common", + "metric", + "snafu 0.8.0", + "tokio-util", + "trace", + "workspace-hack", +] + +[[package]] +name = "ipnet" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" + +[[package]] +name = "is-terminal" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455" +dependencies = [ + "hermit-abi", + "rustix", + "windows-sys 0.52.0", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" + +[[package]] +name = "jobserver" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab46a6e9526ddef3ae7f787c06f0f2600639ba80ea3eade3d8e670a2230f51d6" +dependencies = [ + "libc", 
+] + +[[package]] +name = "js-sys" +version = "0.3.68" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "406cda4b368d531c842222cf9d2600a9a4acce8d29423695379c6868a143a9ee" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "json-patch" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55ff1e1486799e3f64129f8ccad108b38290df9cd7015cd31bed17239f0789d6" +dependencies = [ + "serde", + "serde_json", + "thiserror", + "treediff", +] + +[[package]] +name = "jsonpath-rust" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06cc127b7c3d270be504572364f9569761a180b981919dd0d87693a7f5fb7829" +dependencies = [ + "pest", + "pest_derive", + "regex", + "serde_json", + "thiserror", +] + +[[package]] +name = "k8s-openapi" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edc3606fd16aca7989db2f84bb25684d0270c6d6fa1dbcd0025af7b4130523a6" +dependencies = [ + "base64", + "bytes", + "chrono", + "schemars", + "serde", + "serde-value", + "serde_json", +] + +[[package]] +name = "kqueue" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7447f1ca1b7b563588a205fe93dea8df60fd981423a768bc1c0ded35ed147d0c" +dependencies = [ + "kqueue-sys", + "libc", +] + +[[package]] +name = "kqueue-sys" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed9625ffda8729b85e45cf04090035ac368927b8cebc34898e7c120f52e4838b" +dependencies = [ + "bitflags 1.3.2", + "libc", +] + +[[package]] +name = "kube" +version = "0.87.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3499c8d60c763246c7a213f51caac1e9033f46026904cb89bc8951ae8601f26e" +dependencies = [ + "k8s-openapi", + "kube-client", + "kube-core", + "kube-derive", + "kube-runtime", +] + +[[package]] +name = "kube-client" +version = "0.87.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "033450dfa0762130565890dadf2f8835faedf749376ca13345bcd8ecd6b5f29f" +dependencies = [ + "base64", + "bytes", + "chrono", + "either", + "futures", + "home", + "http", + "http-body", + "hyper", + "hyper-rustls", + "hyper-timeout", + "jsonpath-rust", + "k8s-openapi", + "kube-core", + "pem", + "pin-project", + "rustls", + "rustls-pemfile", + "secrecy", + "serde", + "serde_json", + "serde_yaml", + "thiserror", + "tokio", + "tokio-util", + "tower", + "tower-http", + "tracing", +] + +[[package]] +name = "kube-core" +version = "0.87.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5bba93d054786eba7994d03ce522f368ef7d48c88a1826faa28478d85fb63ae" +dependencies = [ + "chrono", + "form_urlencoded", + "http", + "json-patch", + "k8s-openapi", + "once_cell", + "schemars", + "serde", + "serde_json", + "thiserror", +] + +[[package]] +name = "kube-derive" +version = "0.87.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e98dd5e5767c7b894c1f0e41fd628b145f808e981feb8b08ed66455d47f1a4" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "serde_json", + "syn 2.0.48", +] + +[[package]] +name = "kube-runtime" +version = "0.87.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d8893eb18fbf6bb6c80ef6ee7dd11ec32b1dc3c034c988ac1b3a84d46a230ae" +dependencies = [ + "ahash", + "async-trait", + "backoff 0.4.0", + "derivative", + "futures", + "hashbrown 0.14.3", + "json-patch", + "k8s-openapi", + "kube-client", + "parking_lot", + "pin-project", + "serde", + "serde_json", + "smallvec", + "thiserror", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "kube_test" +version = "0.1.0" +dependencies = [ + "http", + "hyper", + "k8s-openapi", + "kube-core", + "rand", + "serde", + "serde_json", + "serde_yaml", + "tower", + "workspace-hack", +] + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +dependencies = [ + "spin 0.5.2", +] + +[[package]] +name = "lexical-core" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" +dependencies = [ + "lexical-parse-float", + "lexical-parse-integer", + "lexical-util", + "lexical-write-float", + "lexical-write-integer", +] + +[[package]] +name = "lexical-parse-float" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" +dependencies = [ + "lexical-parse-integer", + "lexical-util", + "static_assertions", +] + +[[package]] +name = "lexical-parse-integer" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" +dependencies = [ + "lexical-util", + "static_assertions", +] + +[[package]] +name = "lexical-util" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" +dependencies = [ + "static_assertions", +] + +[[package]] +name = "lexical-write-float" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" +dependencies = [ + "lexical-util", + "lexical-write-integer", + "static_assertions", +] + +[[package]] +name = "lexical-write-integer" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" +dependencies = [ + "lexical-util", + "static_assertions", +] + +[[package]] +name = "libc" +version = "0.2.153" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "libsqlite3-sys" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf4e226dcd58b4be396f7bd3c20da8fdee2911400705297ba7d2d7cc2c30f716" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "linked-hash-map" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" + +[[package]] +name = "linux-raw-sys" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" + +[[package]] +name = "lock_api" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" + +[[package]] +name = "logfmt" +version = "0.1.0" +dependencies = [ + "observability_deps", + "once_cell", + "parking_lot", + "regex", + "tracing-subscriber", + "workspace-hack", +] + +[[package]] +name = "lz4_flex" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "912b45c753ff5f7f5208307e8ace7d2a2e30d024e26d3509f3dce546c044ce15" +dependencies = [ + "twox-hash", +] + +[[package]] +name = "lzma-sys" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"5fda04ab3764e6cde78b9974eec4f779acaba7c4e84b36eca3cf77c581b85d27" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest", +] + +[[package]] +name = "memchr" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" + +[[package]] +name = "memmap2" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe751422e4a8caa417e13c3ea66452215d7d63e19e604f4980461212f3ae1322" +dependencies = [ + "libc", +] + +[[package]] +name = "metric" +version = "0.1.0" +dependencies = [ + "parking_lot", + "workspace-hack", +] + +[[package]] +name = "metric_exporters" +version = "0.1.0" +dependencies = [ + "metric", + "observability_deps", + "prometheus", + "test_helpers", + "workspace-hack", +] + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "miniz_oxide" +version = "0.7.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" +dependencies = [ + "adler", +] + +[[package]] +name = "mio" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" +dependencies = [ + "libc", + "log", + "wasi", + "windows-sys 0.48.0", +] + +[[package]] +name = "mockito" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8d3038e23466858569c2d30a537f691fa0d53b51626630ae08262943e3bbb8b" +dependencies = [ + "assert-json-diff", + "futures", + "hyper", + "log", + "rand", + "regex", + "serde_json", + "serde_urlencoded", + "similar", + "tokio", +] + +[[package]] +name = "moka" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1911e88d5831f748a4097a43862d129e3c6fca831eecac9b8db6d01d93c9de2" +dependencies = [ + "async-lock", + "async-trait", + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "futures-util", + "once_cell", + "parking_lot", + "quanta", + "rustc_version", + "skeptic", + "smallvec", + "tagptr", + "thiserror", + "triomphe", + "uuid", +] + +[[package]] +name = "mpchash" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdd8199faa645318222f8aeb383fca4216a3f75b144f1e264ac74c0835d871a9" +dependencies = [ + "num-traits", + "rand", + "xxhash-rust", +] + +[[package]] +name = "multimap" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" + +[[package]] +name = "murmur3" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9252111cf132ba0929b6f8e030cac2a24b507f3a4d6db6fb2896f27b354c714b" + +[[package]] +name = "mutable_batch" +version = "0.1.0" +dependencies = [ + 
"arrow", + "arrow_util", + "assert_matches", + "data_types", + "hashbrown 0.14.3", + "iox_time", + "itertools 0.12.1", + "mutable_batch_lp", + "partition", + "pretty_assertions", + "proptest", + "rand", + "schema", + "snafu 0.8.0", + "workspace-hack", +] + +[[package]] +name = "mutable_batch_lp" +version = "0.1.0" +dependencies = [ + "arrow_util", + "assert_matches", + "criterion", + "hashbrown 0.14.3", + "influxdb-line-protocol", + "itertools 0.12.1", + "mutable_batch", + "schema", + "snafu 0.8.0", + "test_helpers", + "workspace-hack", +] + +[[package]] +name = "mutable_batch_pb" +version = "0.1.0" +dependencies = [ + "arrow_util", + "data_types", + "dml", + "generated_types", + "hashbrown 0.14.3", + "mutable_batch", + "mutable_batch_lp", + "partition", + "schema", + "snafu 0.8.0", + "workspace-hack", +] + +[[package]] +name = "mutable_batch_tests" +version = "0.1.0" +dependencies = [ + "bytes", + "criterion", + "data_types", + "dml", + "flate2", + "generated_types", + "mutable_batch", + "mutable_batch_lp", + "mutable_batch_pb", + "prost", +] + +[[package]] +name = "nix" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b" +dependencies = [ + "bitflags 1.3.2", + "cfg-if", + "libc", +] + +[[package]] +name = "nix" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053" +dependencies = [ + "bitflags 2.4.2", + "cfg-if", + "libc", +] + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + +[[package]] +name = "non-empty-string" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"55cf0f4060e345ae505219853da9ca1150564158a648a6aa6a528f0d5794bb33" +dependencies = [ + "delegate", +] + +[[package]] +name = "notify" +version = "6.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6205bd8bb1e454ad2e27422015fb5e4f2bcc7e08fa8f27058670d208324a4d2d" +dependencies = [ + "bitflags 2.4.2", + "crossbeam-channel", + "filetime", + "fsevent-sys", + "inotify", + "kqueue", + "libc", + "log", + "mio", + "walkdir", + "windows-sys 0.48.0", +] + +[[package]] +name = "ntapi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" +dependencies = [ + "winapi", +] + +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + +[[package]] +name = "num" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-bigint-dig" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc84195820f291c7697304f3cbdadd1cb7199c0efc917ff5eafd71225c136151" +dependencies = [ + "byteorder", + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand", + "smallvec", + "zeroize", +] + +[[package]] +name = "num-complex" +version = "0.4.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-format" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a652d9771a63711fd3c3deb670acfbe5c30a4072e664d7a3bf5a9e1056ac72c3" +dependencies = [ + "arrayvec", + "itoa", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d869c01cc0c455284163fd0092f1f93835385ccab5a98a0dcc497b2f8bf055a9" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +dependencies = [ + "autocfg", + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "object" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +dependencies = [ + "memchr", +] + +[[package]] +name = "object_store" +version = "0.8.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "2524735495ea1268be33d200e1ee97455096a0846295a21548cd2f3541de7050" +dependencies = [ + "async-trait", + "base64", + "bytes", + "chrono", + "futures", + "humantime", + "hyper", + "itertools 0.11.0", + "parking_lot", + "percent-encoding", + "quick-xml 0.31.0", + "rand", + "reqwest", + "ring", + "rustls-pemfile", + "serde", + "serde_json", + "snafu 0.7.5", + "tokio", + "tracing", + "url", + "walkdir", +] + +[[package]] +name = "object_store_metrics" +version = "0.1.0" +dependencies = [ + "async-trait", + "bytes", + "futures", + "iox_time", + "metric", + "object_store", + "pin-project", + "snafu 0.8.0", + "tokio", + "workspace-hack", +] + +[[package]] +name = "observability_deps" +version = "0.1.0" +dependencies = [ + "tracing", + "workspace-hack", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +dependencies = [ + "parking_lot_core", +] + +[[package]] +name = "oorandom" +version = "11.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" + +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + +[[package]] +name = "ordered-float" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c" +dependencies = [ + "num-traits", +] + +[[package]] +name = "ordered-float" +version = "4.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e" +dependencies = [ + "num-traits", +] + +[[package]] +name = "ouroboros" +version = 
"0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97b7be5a8a3462b752f4be3ff2b2bf2f7f1d00834902e46be2a4d68b87b0573c" +dependencies = [ + "aliasable", + "ouroboros_macro", + "static_assertions", +] + +[[package]] +name = "ouroboros_macro" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b645dcde5f119c2c454a92d0dfa271a2a3b205da92e4292a68ead4bdbfde1f33" +dependencies = [ + "heck", + "itertools 0.12.1", + "proc-macro2", + "proc-macro2-diagnostics", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + +[[package]] +name = "panic_logging" +version = "0.1.0" +dependencies = [ + "metric", + "observability_deps", + "test_helpers", + "workspace-hack", +] + +[[package]] +name = "parking" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" + +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.48.5", +] + +[[package]] +name = "parquet" +version = "49.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af88740a842787da39b3d69ce5fbf6fce97d20211d3b299fee0a0da6430c74d4" +dependencies = [ + "ahash", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ipc", + 
"arrow-schema", + "arrow-select", + "base64", + "brotli", + "bytes", + "chrono", + "flate2", + "futures", + "hashbrown 0.14.3", + "lz4_flex", + "num", + "num-bigint", + "object_store", + "paste", + "seq-macro", + "snap", + "thrift", + "tokio", + "twox-hash", + "zstd", +] + +[[package]] +name = "parquet_cache" +version = "0.1.0" +dependencies = [ + "ahash", + "arc-swap", + "assert_matches", + "async-channel", + "async-trait", + "backoff 0.1.0", + "bytes", + "chrono", + "data_types", + "fnv", + "futures", + "http", + "hyper", + "iox_catalog", + "iox_tests", + "iox_time", + "k8s-openapi", + "kube", + "kube_test", + "lazy_static", + "moka", + "mpchash", + "notify", + "object_store", + "observability_deps", + "parking_lot", + "parquet_file", + "pin-project", + "rand", + "reqwest", + "schemars", + "serde", + "serde_json", + "tempfile", + "thiserror", + "tokio", + "tokio-stream", + "tokio-util", + "tower", + "url", + "uuid", + "workspace-hack", +] + +[[package]] +name = "parquet_file" +version = "0.1.0" +dependencies = [ + "arrow", + "assert_matches", + "base64", + "bytes", + "data_types", + "datafusion", + "datafusion_util", + "futures", + "generated_types", + "iox_time", + "object_store", + "observability_deps", + "parquet", + "pbjson-types", + "prost", + "rand", + "schema", + "snafu 0.8.0", + "test_helpers", + "thiserror", + "thrift", + "tokio", + "uuid", + "workspace-hack", + "zstd", +] + +[[package]] +name = "parquet_to_line_protocol" +version = "0.1.0" +dependencies = [ + "datafusion", + "datafusion_util", + "futures", + "influxdb-line-protocol", + "mutable_batch_lp", + "num_cpus", + "object_store", + "parquet_file", + "schema", + "snafu 0.8.0", + "tokio", + "workspace-hack", +] + +[[package]] +name = "parse-zoneinfo" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c705f256449c60da65e11ff6626e0c16a0a0b96aaa348de61376b249bc340f41" +dependencies = [ + "regex", +] + +[[package]] +name = "partition" +version = "0.1.0" 
+dependencies = [ + "arrow", + "assert_matches", + "chrono", + "criterion", + "data_types", + "generated_types", + "hashbrown 0.14.3", + "mutable_batch", + "mutable_batch_lp", + "paste", + "percent-encoding", + "proptest", + "rand", + "schema", + "test_helpers", + "thiserror", + "unicode-segmentation", + "workspace-hack", +] + +[[package]] +name = "paste" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" + +[[package]] +name = "pbjson" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1030c719b0ec2a2d25a5df729d6cff1acf3cc230bf766f4f97833591f7577b90" +dependencies = [ + "base64", + "serde", +] + +[[package]] +name = "pbjson-build" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2580e33f2292d34be285c5bc3dba5259542b083cfad6037b6d70345f24dcb735" +dependencies = [ + "heck", + "itertools 0.11.0", + "prost", + "prost-types", +] + +[[package]] +name = "pbjson-types" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18f596653ba4ac51bdecbb4ef6773bc7f56042dc13927910de1684ad3d32aa12" +dependencies = [ + "bytes", + "chrono", + "pbjson", + "pbjson-build", + "prost", + "prost-build", + "serde", +] + +[[package]] +name = "pdatastructs" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bdcb4943c3c68659690124771ffb2fd93b73900bd0fb47e934f7b8b2e6687fa" +dependencies = [ + "fixedbitset", +] + +[[package]] +name = "pem" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b8fcc794035347fb64beda2d3b462595dd2753e3f268d89c5aae77e8cf2c310" +dependencies = [ + "base64", + "serde", +] + +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + +[[package]] +name = "pest" +version = "2.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "219c0dcc30b6a27553f9cc242972b67f75b60eb0db71f0b5462f38b058c41546" +dependencies = [ + "memchr", + "thiserror", + "ucd-trie", +] + +[[package]] +name = "pest_derive" +version = "2.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e1288dbd7786462961e69bfd4df7848c1e37e8b74303dbdab82c3a9cdd2809" +dependencies = [ + "pest", + "pest_generator", +] + +[[package]] +name = "pest_generator" +version = "2.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1381c29a877c6d34b8c176e734f35d7f7f5b3adaefe940cb4d1bb7af94678e2e" +dependencies = [ + "pest", + "pest_meta", + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "pest_meta" +version = "2.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0934d6907f148c22a3acbda520c7eed243ad7487a30f51f6ce52b58b7077a8a" +dependencies = [ + "once_cell", + "pest", + "sha2", +] + +[[package]] +name = "petgraph" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" +dependencies = [ + "fixedbitset", + "indexmap 2.2.2", +] + +[[package]] +name = "phf" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_codegen" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +dependencies = [ + "phf_generator", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +dependencies = [ + "phf_shared", + "rand", +] + +[[package]] +name = "phf_shared" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +dependencies = [ + "siphasher 0.3.11", +] + +[[package]] +name = "pin-project" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0302c4a0442c456bd56f841aee5c3bfd17967563f6fadc9ceb9f9c23cf3807e0" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "266c042b60c9c76b8d53061e52b2e0d1116abc57cefc8c5cd671619a56ac3690" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der", + "pkcs8", + "spki", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" 
+dependencies = [ + "der", + "spki", +] + +[[package]] +name = "pkg-config" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" + +[[package]] +name = "platforms" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "626dec3cac7cc0e1577a2ec3fc496277ec2baa084bebad95bb6fdbfae235f84c" + +[[package]] +name = "pprof" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef5c97c51bd34c7e742402e216abdeb44d415fbe6ae41d56b114723e953711cb" +dependencies = [ + "backtrace", + "cfg-if", + "findshlibs", + "inferno", + "libc", + "log", + "nix 0.26.4", + "once_cell", + "parking_lot", + "prost", + "prost-build", + "prost-derive", + "protobuf", + "sha2", + "smallvec", + "symbolic-demangle", + "tempfile", + "thiserror", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "predicate" +version = "0.1.0" +dependencies = [ + "arrow", + "chrono", + "data_types", + "datafusion", + "datafusion_util", + "itertools 0.12.1", + "observability_deps", + "query_functions", + "schema", + "snafu 0.8.0", + "sqlparser", + "test_helpers", + "workspace-hack", +] + +[[package]] +name = "predicates" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68b87bfd4605926cdfefc1c3b5f8fe560e3feca9d5552cf68c466d3d8236c7e8" +dependencies = [ + "anstyle", + "difflib", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b794032607612e7abeb4db69adb4e33590fa6cf1149e95fd7cb00e634b92f174" + +[[package]] +name = "predicates-tree" +version = "1.0.9" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "368ba315fb8c5052ab692e68a0eefec6ec57b23a36959c14496f0b0df2c0cecf" +dependencies = [ + "predicates-core", + "termtree", +] + +[[package]] +name = "pretty_assertions" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af7cee1a6c8a5b9208b3cb1061f10c0cb689087b3d8ce85fb9d2dd7a29b6ba66" +dependencies = [ + "diff", + "yansi 0.5.1", +] + +[[package]] +name = "prettyplease" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a41cf62165e97c7f814d2221421dbb9afcbcdb0a88068e5ea206e19951c2cbb5" +dependencies = [ + "proc-macro2", + "syn 2.0.48", +] + +[[package]] +name = "proc-macro2" +version = "1.0.78" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "proc-macro2-diagnostics" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", + "version_check", + "yansi 1.0.0-rc.1", +] + +[[package]] +name = "prometheus" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "449811d15fbdf5ceb5c1144416066429cf82316e2ec8ce0c1f6f8a02e7bbcf8c" +dependencies = [ + "cfg-if", + "fnv", + "lazy_static", + "memchr", + "parking_lot", + "thiserror", +] + +[[package]] +name = "proptest" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31b476131c3c86cb68032fdc5cb6d5a1045e3e42d96b69fa599fd77701e1f5bf" +dependencies = [ + "bitflags 2.4.2", + "lazy_static", + "num-traits", + "rand", + "rand_chacha", + "rand_xorshift", + "regex-syntax 0.8.2", + "unarray", +] + +[[package]] +name = "prost" +version = "0.12.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c289cda302b98a28d40c8b3b90498d6e526dd24ac2ecea73e4e491685b94a" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-build" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c55e02e35260070b6f716a2423c2ff1c3bb1642ddca6f99e1f26d06268a0e2d2" +dependencies = [ + "bytes", + "heck", + "itertools 0.11.0", + "log", + "multimap", + "once_cell", + "petgraph", + "prettyplease", + "prost", + "prost-types", + "regex", + "syn 2.0.48", + "tempfile", + "which", +] + +[[package]] +name = "prost-derive" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efb6c9a1dd1def8e2124d17e83a20af56f1570d6c2d2bd9e266ccb768df3840e" +dependencies = [ + "anyhow", + "itertools 0.11.0", + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "prost-types" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "193898f59edcf43c26227dcd4c8427f00d99d61e95dcde58dabd49fa291d470e" +dependencies = [ + "prost", +] + +[[package]] +name = "protobuf" +version = "2.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94" + +[[package]] +name = "pulldown-cmark" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57206b407293d2bcd3af849ce869d52068623f19e1b5ff8e8778e3309439682b" +dependencies = [ + "bitflags 2.4.2", + "memchr", + "unicase", +] + +[[package]] +name = "quanta" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ca0b7bac0b97248c40bb77288fc52029cf1459c0461ea1b05ee32ccf011de2c" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + +[[package]] +name = "query_functions" +version = "0.1.0" +dependencies = [ + 
"arrow", + "chrono", + "datafusion", + "datafusion_util", + "itertools 0.12.1", + "once_cell", + "regex", + "regex-syntax 0.8.2", + "schema", + "snafu 0.8.0", + "tokio", + "workspace-hack", +] + +[[package]] +name = "quick-xml" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f50b1c63b38611e7d4d7f68b82d3ad0cc71a2ad2e7f61fc10f1328d917c93cd" +dependencies = [ + "memchr", +] + +[[package]] +name = "quick-xml" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1004a344b30a54e2ee58d66a71b32d2db2feb0a31f9a2d302bf0536f15de2a33" +dependencies = [ + "memchr", + "serde", +] + +[[package]] +name = "quote" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_xorshift" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f" +dependencies = [ + "rand_core", +] + +[[package]] +name = "raw-cpuid" +version = "11.0.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d86a7c4638d42c44551f4791a20e687dbb4c3de1f33c43dd71e355cd429def1" +dependencies = [ + "bitflags 2.4.2", +] + +[[package]] +name = "rayon" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags 1.3.2", +] + +[[package]] +name = "regex" +version = "1.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata 0.4.5", + "regex-syntax 0.8.2", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", +] + +[[package]] +name = "regex-automata" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.8.2", +] + +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + +[[package]] +name = "regex-syntax" 
+version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" + +[[package]] +name = "reqwest" +version = "0.11.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6920094eb85afde5e4a138be3f2de8bbdf28000f0029e72c45025a56b042251" +dependencies = [ + "base64", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "hyper", + "hyper-rustls", + "ipnet", + "js-sys", + "log", + "mime", + "once_cell", + "percent-encoding", + "pin-project-lite", + "rustls", + "rustls-native-certs", + "rustls-pemfile", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "system-configuration", + "tokio", + "tokio-rustls", + "tokio-util", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", + "web-sys", + "webpki-roots", + "winreg", +] + +[[package]] +name = "rgb" +version = "0.8.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05aaa8004b64fd573fc9d002f4e632d51ad4f026c2b5ba95fcb6c2f32c2c47d8" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "ring" +version = "0.17.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" +dependencies = [ + "cc", + "getrandom", + "libc", + "spin 0.9.8", + "untrusted", + "windows-sys 0.48.0", +] + +[[package]] +name = "rsa" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d0e5124fcb30e76a7e79bfee683a2746db83784b86289f6251b54b7950a0dfc" +dependencies = [ + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-traits", + "pkcs1", + "pkcs8", + "rand_core", + "signature", + "spki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" + +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + +[[package]] +name = "rustix" +version = "0.38.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" +dependencies = [ + "bitflags 2.4.2", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustls" +version = "0.21.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" +dependencies = [ + "log", + "ring", + "rustls-webpki", + "sct", +] + +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +dependencies = [ + "base64", +] + +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring", + "untrusted", +] + +[[package]] +name = "rustversion" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" + +[[package]] +name = "ryu" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "schannel" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "schema" +version = "0.1.0" +dependencies = [ + "arrow", + "hashbrown 0.14.3", + "indexmap 2.2.2", + "observability_deps", + "once_cell", + "snafu 0.8.0", + "workspace-hack", +] + +[[package]] +name = "schemars" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45a28f4c49489add4ce10783f7911893516f15afe45d015608d41faca6bc4d29" +dependencies = [ + "dyn-clone", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c767fd6fa65d9ccf9cf026122c1b555f2ef9a4f0cea69da4d7dbc3e258d30967" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn 1.0.109", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring", + "untrusted", +] + +[[package]] +name = "secrecy" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9bd1c54ea06cfd2f6b63219704de0b9b4f72dcc2b8fdef820be6cd799780e91e" +dependencies = [ + 
"serde", + "zeroize", +] + +[[package]] +name = "security-framework" +version = "2.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e932934257d3b408ed8f30db49d85ea163bfe74961f017f405b025af298f0c7a" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "semver" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0" +dependencies = [ + "serde", +] + +[[package]] +name = "seq-macro" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" + +[[package]] +name = "serde" +version = "1.0.196" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde-value" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c" +dependencies = [ + "ordered-float 2.10.1", + "serde", +] + +[[package]] +name = "serde_derive" +version = "1.0.196" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "serde_derive_internals" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"85bf8229e7920a9f636479437026331ce11aa132b4dde37d121944a44d6e5f3c" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "serde_json" +version = "1.0.113" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69801b70b1c3dac963ecb03a364ba0ceda9cf60c71cfe475e99864759c8b8a79" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "serde_spanned" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb3622f419d1296904700073ea6cc23ad690adbd66f13ea683df73298736f0c1" +dependencies = [ + "serde", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "serde_yaml" +version = "0.9.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adf8a49373e98a4c5f0ceb5d05aa7c648d75f63774981ed95b7c7443bbd50c6e" +dependencies = [ + "indexmap 2.2.2", + "itoa", + "ryu", + "serde", + "unsafe-libyaml", +] + +[[package]] +name = "service_common" +version = "0.1.0" +dependencies = [ + "arrow", + "datafusion", + "executor", + "tonic", + "workspace-hack", +] + +[[package]] +name = "service_grpc_flight" +version = "0.1.0" +dependencies = [ + "arrow", + "arrow-flight", + "assert_matches", + "async-trait", + "authz", + "bytes", + "data_types", + "datafusion", + "flightsql", + "futures", + "generated_types", + "iox_query", + "iox_query_influxql", + "iox_query_params", + "metric", + "observability_deps", + "prost", + "serde", + "serde_json", + "service_common", + "snafu 0.8.0", + "test_helpers", + "tokio", + "tonic", + "tower_trailer", + "trace", + "trace_http", + "tracker", + "workspace-hack", +] + +[[package]] +name = "service_grpc_testing" +version = "0.1.0" +dependencies = [ + "generated_types", + 
"observability_deps", + "tonic", + "workspace-hack", +] + +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "sharder" +version = "0.1.0" +dependencies = [ + "criterion", + "data_types", + "hashbrown 0.14.3", + "mutable_batch", + "mutable_batch_lp", + "parking_lot", + "rand", + "siphasher 1.0.0", + "workspace-hack", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "digest", + "rand_core", +] + +[[package]] +name = "similar" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32fea41aca09ee824cc9724996433064c89f7777e60762749a4170a14abbfa21" + +[[package]] +name = "siphasher" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" + +[[package]] +name = "siphasher" +version = "1.0.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "54ac45299ccbd390721be55b412d41931911f654fa99e2cb8bfb57184b2061fe" + +[[package]] +name = "skeptic" +version = "0.13.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16d23b015676c90a0f01c197bfdc786c20342c73a0afdda9025adb0bc42940a8" +dependencies = [ + "bytecount", + "cargo_metadata", + "error-chain", + "glob", + "pulldown-cmark", + "tempfile", + "walkdir", +] + +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + +[[package]] +name = "smallvec" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" + +[[package]] +name = "snafu" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" +dependencies = [ + "doc-comment", + "snafu-derive 0.7.5", +] + +[[package]] +name = "snafu" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d342c51730e54029130d7dc9fd735d28c4cd360f1368c01981d4f03ff207f096" +dependencies = [ + "snafu-derive 0.8.0", +] + +[[package]] +name = "snafu-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "snafu-derive" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "080c44971436b1af15d6f61ddd8b543995cf63ab8e677d46b00cc06f4ef267a0" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "snap" +version = "1.1.1" +source 
= "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" + +[[package]] +name = "socket2" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +dependencies = [ + "libc", + "windows-sys 0.48.0", +] + +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + +[[package]] +name = "sqlformat" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce81b7bd7c4493975347ef60d8c7e8b742d4694f4c49f93e0a12ea263938176c" +dependencies = [ + "itertools 0.12.1", + "nom", + "unicode_categories", +] + +[[package]] +name = "sqlparser" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cc2c25a6c66789625ef164b4c7d2e548d627902280c13710d33da8222169964" +dependencies = [ + "log", + "sqlparser_derive", +] + +[[package]] +name = "sqlparser_derive" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "sqlx" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"dba03c279da73694ef99763320dea58b51095dfe87d001b1d4b5fe78ba8763cf" +dependencies = [ + "sqlx-core", + "sqlx-macros", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", +] + +[[package]] +name = "sqlx-core" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d84b0a3c3739e220d94b3239fd69fb1f74bc36e16643423bd99de3b43c21bfbd" +dependencies = [ + "ahash", + "atoi", + "byteorder", + "bytes", + "crc", + "crossbeam-queue", + "dotenvy", + "either", + "event-listener 2.5.3", + "futures-channel", + "futures-core", + "futures-intrusive", + "futures-io", + "futures-util", + "hashlink", + "hex", + "indexmap 2.2.2", + "log", + "memchr", + "once_cell", + "paste", + "percent-encoding", + "rustls", + "rustls-pemfile", + "serde", + "serde_json", + "sha2", + "smallvec", + "sqlformat", + "thiserror", + "tokio", + "tokio-stream", + "tracing", + "url", + "uuid", + "webpki-roots", +] + +[[package]] +name = "sqlx-hotswap-pool" +version = "0.1.0" +dependencies = [ + "dotenvy", + "either", + "futures", + "rand", + "sqlx", + "tokio", + "workspace-hack", +] + +[[package]] +name = "sqlx-macros" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89961c00dc4d7dffb7aee214964b065072bff69e36ddb9e2c107541f75e4f2a5" +dependencies = [ + "proc-macro2", + "quote", + "sqlx-core", + "sqlx-macros-core", + "syn 1.0.109", +] + +[[package]] +name = "sqlx-macros-core" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0bd4519486723648186a08785143599760f7cc81c52334a55d6a83ea1e20841" +dependencies = [ + "atomic-write-file", + "dotenvy", + "either", + "heck", + "hex", + "once_cell", + "proc-macro2", + "quote", + "serde", + "serde_json", + "sha2", + "sqlx-core", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", + "syn 1.0.109", + "tempfile", + "tokio", + "url", +] + +[[package]] +name = "sqlx-mysql" +version = "0.7.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4" +dependencies = [ + "atoi", + "base64", + "bitflags 2.4.2", + "byteorder", + "bytes", + "crc", + "digest", + "dotenvy", + "either", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "generic-array", + "hex", + "hkdf", + "hmac", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "percent-encoding", + "rand", + "rsa", + "serde", + "sha1", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-postgres" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24" +dependencies = [ + "atoi", + "base64", + "bitflags 2.4.2", + "byteorder", + "crc", + "dotenvy", + "etcetera", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "hex", + "hkdf", + "hmac", + "home", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "rand", + "serde", + "serde_json", + "sha1", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-sqlite" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "210976b7d948c7ba9fced8ca835b11cbb2d677c59c79de41ac0d397e14547490" +dependencies = [ + "atoi", + "flume", + "futures-channel", + "futures-core", + "futures-executor", + "futures-intrusive", + "futures-util", + "libsqlite3-sys", + "log", + "percent-encoding", + "serde", + "sqlx-core", + "tracing", + "url", + "urlencoding", + "uuid", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + +[[package]] +name = "static_assertions" +version = "1.1.0" +source 
= "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "str_stack" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9091b6114800a5f2141aee1d1b9d6ca3592ac062dc5decb3764ec5895a47b4eb" + +[[package]] +name = "stringprep" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb41d74e231a107a1b4ee36bd1214b11285b77768d2e3824aedafa988fd36ee6" +dependencies = [ + "finl_unicode", + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + +[[package]] +name = "strum" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.48", +] + +[[package]] +name = "subtle" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" + +[[package]] +name = "symbolic-common" +version = "12.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cccfffbc6bb3bb2d3a26cd2077f4d055f6808d266f9d4d158797a4c60510dfe" +dependencies = [ + "debugid", + "memmap2", + "stable_deref_trait", + "uuid", +] + +[[package]] +name = "symbolic-demangle" +version = "12.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"76a99812da4020a67e76c4eb41f08c87364c14170495ff780f30dd519c221a68" +dependencies = [ + "cpp_demangle", + "rustc-demangle", + "symbolic-common", +] + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + +[[package]] +name = "synchronized-writer" +version = "1.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3543ca0810e71767052bdcdd5653f23998b192642a22c5164bfa6581e40a4a2" + +[[package]] +name = "sysinfo" +version = "0.30.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fb4f3438c8f6389c864e61221cbc97e9bca98b4daf39a5beb7bea660f528bb2" +dependencies = [ + "cfg-if", + "core-foundation-sys", + "libc", + "ntapi", + "once_cell", + "rayon", + "windows", +] + +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = 
"tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + +[[package]] +name = "tempfile" +version = "3.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a365e8cd18e44762ef95d87f284f4b5cd04107fec2ff3052bd6a3e6069669e67" +dependencies = [ + "cfg-if", + "fastrand", + "rustix", + "windows-sys 0.52.0", +] + +[[package]] +name = "termtree" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" + +[[package]] +name = "test_helpers" +version = "0.1.0" +dependencies = [ + "async-trait", + "dotenvy", + "observability_deps", + "parking_lot", + "tempfile", + "tokio", + "tracing-log", + "tracing-subscriber", + "workspace-hack", +] + +[[package]] +name = "test_helpers_end_to_end" +version = "0.1.0" +dependencies = [ + "arrow", + "arrow-flight", + "arrow_util", + "assert_cmd", + "assert_matches", + "bytes", + "data_types", + "dml", + "futures", + "generated_types", + "http", + "hyper", + "influxdb_iox_client", + "ingester_query_grpc", + "insta", + "iox_catalog", + "iox_query_params", + "mutable_batch_lp", + "mutable_batch_pb", + "nix 0.27.1", + "observability_deps", + "once_cell", + "parking_lot", + "prost", + "rand", + "regex", + "reqwest", + "serde_json", + "snafu 0.8.0", + "sqlx", + "tempfile", + "test_helpers", + "tokio", + "tokio-util", + "tonic", + "workspace-hack", +] + +[[package]] +name = "thiserror" +version = "1.0.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" 
+dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "thread_local" +version = "1.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" +dependencies = [ + "cfg-if", + "once_cell", +] + +[[package]] +name = "threadpool" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d050e60b33d41c19108b32cea32164033a9013fe3b46cbd4457559bfbf77afaa" +dependencies = [ + "num_cpus", +] + +[[package]] +name = "thrift" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e54bc85fc7faa8bc175c4bab5b92ba8d9a3ce893d0e9f42cc455c8ab16a9e09" +dependencies = [ + "byteorder", + "integer-encoding 3.0.4", + "log", + "ordered-float 2.10.1", + "threadpool", +] + +[[package]] +name = "tikv-jemalloc-sys" +version = "0.5.4+5.3.0-patched" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9402443cb8fd499b6f327e40565234ff34dbda27460c5b47db0db77443dd85d1" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "tinyvec" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "tokio" +version = "1.36.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "num_cpus", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "tracing", + "windows-sys 0.48.0", +] + +[[package]] +name = "tokio-io-timeout" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30b74022ada614a1b4834de765f9bb43877f910cc8ce4be40e89042c9223a8bf" +dependencies = [ + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-macros" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] +name = "tokio-stream" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" +dependencies = [ + "bytes", + "futures-core", + "futures-io", + "futures-sink", + "pin-project-lite", + "slab", + "tokio", + "tracing", +] + +[[package]] +name = "tokio_metrics_bridge" +version = "0.1.0" +dependencies = [ + "metric", + "parking_lot", + "tokio", + 
"workspace-hack", +] + +[[package]] +name = "tokio_watchdog" +version = "0.1.0" +dependencies = [ + "metric", + "observability_deps", + "test_helpers", + "tokio", + "workspace-hack", +] + +[[package]] +name = "toml" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a9aad4a3066010876e8dcf5a8a06e70a558751117a145c6ce2b82c2e2054290" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c9ffdf896f8daaabf9b66ba8e77ea1ed5ed0f72821b398aba62352e95062951" +dependencies = [ + "indexmap 2.2.2", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + +[[package]] +name = "tonic" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d560933a0de61cf715926b9cac824d4c883c2c43142f787595e48280c40a1d0e" +dependencies = [ + "async-stream", + "async-trait", + "axum", + "base64", + "bytes", + "h2", + "http", + "http-body", + "hyper", + "hyper-timeout", + "percent-encoding", + "pin-project", + "prost", + "rustls", + "rustls-native-certs", + "rustls-pemfile", + "tokio", + "tokio-rustls", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tonic-build" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d021fc044c18582b9a2408cd0dd05b1596e3ecdb5c4df822bb0183545683889" +dependencies = [ + "prettyplease", + "proc-macro2", + "prost-build", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "tonic-health" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "f80db390246dfb46553481f6024f0082ba00178ea495dbb99e70ba9a4fafb5e1" +dependencies = [ + "async-stream", + "prost", + "tokio", + "tokio-stream", + "tonic", +] + +[[package]] +name = "tonic-reflection" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fa37c513df1339d197f4ba21d28c918b9ef1ac1768265f11ecb6b7f1cba1b76" +dependencies = [ + "prost", + "prost-types", + "tokio", + "tokio-stream", + "tonic", +] + +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "indexmap 1.9.3", + "pin-project", + "pin-project-lite", + "rand", + "slab", + "tokio", + "tokio-util", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c5bb1d698276a2443e5ecfabc1008bf15a36c12e6a7176e7bf089ea9131140" +dependencies = [ + "base64", + "bitflags 2.4.2", + "bytes", + "futures-core", + "futures-util", + "http", + "http-body", + "http-range-header", + "mime", + "pin-project-lite", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" + +[[package]] +name = "tower-service" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" + +[[package]] +name = "tower_trailer" +version = "0.1.0" +dependencies = [ + "futures", + "http", + "http-body", + "parking_lot", + "pin-project", + "tower", + "workspace-hack", +] + +[[package]] +name = "trace" +version = "0.1.0" +dependencies = [ + "chrono", + "observability_deps", + 
"parking_lot", + "rand", + "workspace-hack", +] + +[[package]] +name = "trace_exporters" +version = "0.1.0" +dependencies = [ + "async-trait", + "chrono", + "clap", + "futures", + "iox_time", + "observability_deps", + "snafu 0.8.0", + "thrift", + "tokio", + "trace", + "workspace-hack", +] + +[[package]] +name = "trace_http" +version = "0.1.0" +dependencies = [ + "bytes", + "futures", + "hashbrown 0.14.3", + "http", + "http-body", + "itertools 0.12.1", + "metric", + "observability_deps", + "parking_lot", + "pin-project", + "snafu 0.8.0", + "tower", + "trace", + "workspace-hack", +] + +[[package]] +name = "tracing" +version = "0.1.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +dependencies = [ + "log", + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "tracing-core" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-serde" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1" +dependencies = [ + "serde", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.18" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "parking_lot", + "regex", + "serde", + "serde_json", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", + "tracing-serde", +] + +[[package]] +name = "tracker" +version = "0.1.0" +dependencies = [ + "futures", + "hashbrown 0.14.3", + "iox_time", + "lock_api", + "metric", + "observability_deps", + "parking_lot", + "pin-project", + "sysinfo", + "tempfile", + "test_helpers", + "tokio", + "tokio-util", + "trace", + "workspace-hack", +] + +[[package]] +name = "treediff" +version = "4.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d127780145176e2b5d16611cc25a900150e86e9fd79d3bde6ff3a37359c9cb5" +dependencies = [ + "serde_json", +] + +[[package]] +name = "triomphe" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "859eb650cfee7434994602c3a68b25d77ad9e68c8a6cd491616ef86661382eb3" + +[[package]] +name = "trogging" +version = "0.1.0" +dependencies = [ + "clap", + "logfmt", + "observability_deps", + "regex", + "synchronized-writer", + "thiserror", + "tracing-log", + "tracing-subscriber", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "twox-hash" +version = "1.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" +dependencies = [ + "cfg-if", + "static_assertions", +] + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + +[[package]] 
+name = "ucd-trie" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed646292ffc8188ef8ea4d1e0e0150fb15a5c2e12ad9b8fc191ae7a8a7f3c4b9" + +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + +[[package]] +name = "unicase" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89" +dependencies = [ + "version_check", +] + +[[package]] +name = "unicode-bidi" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "unicode-normalization" +version = "0.1.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "unicode-segmentation" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" + +[[package]] +name = "unicode-width" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" + +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + +[[package]] +name = "unsafe-libyaml" +version = "0.2.10" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab4c90930b95a82d00dc9e9ac071b4991924390d46cbd0dfe566148667605e4b" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "url" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + +[[package]] +name = "utf8parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" + +[[package]] +name = "uuid" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" +dependencies = [ + "getrandom", +] + +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "wait-timeout" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9f200f5b12eb75f8c1ed65abd4b2db8a6e1b138a20de009dacee265a2498f3f6" +dependencies = [ + "libc", +] + +[[package]] +name = "wal" +version = "0.1.0" +dependencies = [ + "assert_matches", + "byteorder", + "crc32fast", + "data_types", + "dml", + "generated_types", + "hashbrown 0.14.3", + "mutable_batch", + "mutable_batch_lp", + "mutable_batch_pb", + "observability_deps", + "parking_lot", + "prost", + "snafu 0.8.0", + "snap", + "test_helpers", + "tokio", + "workspace-hack", +] + +[[package]] +name = "wal_inspect" +version = "0.1.0" +dependencies = [ + "data_types", + "dml", + "generated_types", + "hashbrown 0.14.3", + "mutable_batch", + "mutable_batch_lp", + "mutable_batch_pb", + "parquet_to_line_protocol", + "schema", + "test_helpers", + "thiserror", + "tokio", + "wal", + "workspace-hack", +] + +[[package]] +name = "walkdir" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1e124130aee3fb58c5bdd6b639a0509486b0338acaaae0c84a5124b0f588b7f" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9e7e1900c352b609c8488ad12639a311045f40a35491fb69ba8c12f758af70b" +dependencies = [ + "bumpalo", + 
"log", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.48", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877b9c3f61ceea0e56331985743b13f3d25c406a7098d45180fb5f09bc19ed97" +dependencies = [ + "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b30af9e2d358182b5c7449424f017eba305ed32a7010509ede96cdc4696c46ed" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f186bd2dcf04330886ce82d6f33dd75a7bfcf69ecf5763b89fcde53b6ac9838" + +[[package]] +name = "wasm-streams" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b65dc4c90b63b118468cf747d8bf3566c1913ef60be765b5730ead9e0a3ba129" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "web-sys" +version = "0.3.68" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96565907687f7aceb35bc5fc03770a8a0471d82e479f25832f54a0e3f4b28446" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki-roots" +version = "0.25.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" + +[[package]] +name = "which" +version 
= "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix", +] + +[[package]] +name = "whoami" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22fc3756b8a9133049b26c7f61ab35416c130e8c09b660f5b3958b446f52cc50" + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +dependencies = [ + "winapi", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" +dependencies = [ + "windows-core", + "windows-targets 0.52.0", +] + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.0", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.0", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +dependencies = [ + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] 
+name = "windows_aarch64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" + +[[package]] +name = "winnow" +version = "0.5.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5389a154b01683d28c77f8f68f49dea75f0a4da32557a58f68ee51ebba472d29" +dependencies = [ + "memchr", +] + +[[package]] +name = "winreg" +version = "0.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" +dependencies = [ + "cfg-if", + "windows-sys 0.48.0", +] + +[[package]] +name = "workspace-hack" +version = "0.1.0" +dependencies = [ + "ahash", + "arrow", + "arrow-ipc", + "base64", + "bitflags 2.4.2", + "byteorder", + "bytes", + "cc", + "chrono", + "clap", + "clap_builder", + "crossbeam-epoch", + "crossbeam-utils", + "crypto-common", + "digest", + "either", + "fixedbitset", + "flatbuffers", + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", + "getrandom", + "hashbrown 0.14.3", + "heck", + "hyper", + "hyper-rustls", + "indexmap 2.2.2", + "itertools 0.11.0", + "k8s-openapi", + "kube-core", + "libc", + "lock_api", + "log", + "md-5", + "memchr", + "mio", + "nix 0.27.1", + "nom", + "num-traits", + "object_store", + "once_cell", + "parking_lot", + "percent-encoding", + "petgraph", + "phf_shared", + "proptest", + "prost", + "prost-types", + "rand", + "rand_core", + "regex", + "regex-automata 0.4.5", + "regex-syntax 0.8.2", + "reqwest", + "ring", + "rustls", + "serde", + "serde_json", + "sha2", + "similar", + "spin 0.9.8", + "sqlparser", + "sqlx", + "sqlx-core", + "sqlx-macros", + "sqlx-macros-core", + "sqlx-postgres", + "sqlx-sqlite", 
+ "strum", + "syn 1.0.109", + "syn 2.0.48", + "thrift", + "tokio", + "tokio-stream", + "tokio-util", + "tower", + "tower-http", + "tracing", + "tracing-core", + "tracing-log", + "tracing-subscriber", + "unicode-bidi", + "unicode-normalization", + "url", + "uuid", + "winapi", + "windows-sys 0.48.0", + "windows-sys 0.52.0", +] + +[[package]] +name = "xxhash-rust" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53be06678ed9e83edb1745eb72efc0bbcd7b5c3c35711a860906aed827a13d61" + +[[package]] +name = "xz2" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "388c44dc09d76f1536602ead6d325eb532f5c122f17782bd57fb47baeeb767e2" +dependencies = [ + "lzma-sys", +] + +[[package]] +name = "yaml-rust" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56c1936c4cc7a1c9ab21a1ebb602eb942ba868cbd44a99cb7cdc5892335e1c85" +dependencies = [ + "linked-hash-map", +] + +[[package]] +name = "yansi" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" + +[[package]] +name = "yansi" +version = "1.0.0-rc.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1367295b8f788d371ce2dbc842c7b709c73ee1364d30351dd300ec2203b12377" + +[[package]] +name = "zerocopy" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "zeroize" +version = "1.7.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" + +[[package]] +name = "zstd" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43747c7422e2924c11144d5229878b98180ef8b06cca4ab5af37afc8a8d8ea3e" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.9+zstd.1.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e16efa8a874a0481a574084d34cc26fdb3b99627480f785888deb6386506656" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..5aecdd5 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,146 @@ +[workspace] +# In alphabetical order +members = [ + "arrow_util", + "backoff", + "cache_system", + "clap_blocks", + "client_util", + "data_types", + "datafusion_util", + "dml", + "executor", + "flightsql", + "generated_types", + "grpc-binary-logger-proto", + "grpc-binary-logger-test-proto", + "grpc-binary-logger", + "import_export", + "influxdb_influxql_parser", + "influxdb_iox_client", + "influxdb_line_protocol", + "influxdb_storage_client", + "influxdb_tsm", + "influxdb2_client", + "influxrpc_parser", + "iox_catalog", + "iox_data_generator", + "iox_query_influxql", + "iox_query_influxrpc", + "iox_query", + "iox_tests", + "iox_time", + "ioxd_common", + "ioxd_test", + "logfmt", + "metric_exporters", + "metric", + "mutable_batch_lp", + "mutable_batch_pb", + "mutable_batch_tests", + "mutable_batch", + "object_store_metrics", + "observability_deps", + "panic_logging", + "parquet_file", + "parquet_to_line_protocol", + "predicate", + "query_functions", + "schema", + "service_common", + 
"service_grpc_flight", + "service_grpc_testing", + "sharder", + "sqlx-hotswap-pool", + "test_helpers_end_to_end", + "tokio_metrics_bridge", + "trace_exporters", + "trace_http", + "trace", + "tracker", + "trogging", + "wal_inspect", + "wal", + "workspace-hack", +] + +resolver = "2" + +exclude = [ + "*.md", + "*.txt", + ".git*", + ".github/", + "LICENSE*", + "massif.out.*", + "test_bench/", + "test_fixtures/", +] + +[workspace.package] +version = "0.1.0" +authors = ["IOx Project Developers"] +edition = "2021" +license = "MIT OR Apache-2.0" + +[workspace.dependencies] +arrow = { version = "49.0.0", features = ["prettyprint", "chrono-tz"] } +arrow-buffer = { version = "49.0.0" } +arrow-flight = { version = "49.0.0", features = ["flight-sql-experimental"] } +datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7" } +datafusion-proto = { git = "https://github.com/apache/arrow-datafusion.git", rev = "0e53c6d816f3a9d3d27c6ebb6d25b1699e5553e7" } +hashbrown = { version = "0.14.3" } +object_store = { version = "0.8.0" } +parquet = { version = "49.0.0", features = ["object_store"] } +pbjson = { version = "0.6.0" } +pbjson-build = { version = "0.6.2" } +pbjson-types = { version = "0.6.0" } +prost = { version = "0.12.3" } +prost-build = { version = "0.12.2" } +prost-types = { version = "0.12.3" } +sqlparser = { version = "0.41.0" } +tonic = { version = "0.10.2", features = ["tls", "tls-roots"] } +tonic-build = { version = "0.10.2" } +tonic-health = { version = "0.10.2" } +tonic-reflection = { version = "0.10.2" } + +[workspace.lints.rust] +rust_2018_idioms = "deny" +unreachable_pub = "deny" +missing_debug_implementations = "deny" +missing_copy_implementations = "deny" + +[workspace.lints.clippy] +dbg_macro = "deny" +todo = "deny" +clone_on_ref_ptr = "deny" +future_not_send = "deny" + +[workspace.lints.rustdoc] +broken_intra_doc_links = "deny" +bare_urls = "deny" + +# This profile optimizes for runtime performance 
and small binary size at the expense of longer +# build times. It's most suitable for final release builds. +[profile.release] +codegen-units = 16 +debug = true +lto = "thin" + +[profile.bench] +debug = true + +# This profile optimizes for short build times at the expense of larger binary size and slower +# runtime performance. It's most suitable for development iterations. +[profile.quick-release] +inherits = "release" +codegen-units = 16 +lto = false +incremental = true + +# Per insta docs: https://insta.rs/docs/quickstart/#optional-faster-runs +[profile.dev.package.insta] +opt-level = 3 + +[profile.dev.package.similar] +opt-level = 3 diff --git a/arrow_util/Cargo.toml b/arrow_util/Cargo.toml new file mode 100644 index 0000000..18ac4bf --- /dev/null +++ b/arrow_util/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "arrow_util" +description = "Apache Arrow utilities" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[lints] +workspace = true + +[dependencies] +ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } +arrow = { workspace = true } +# used by arrow anyway (needed for printing workaround) +chrono = { version = "0.4", default-features = false } +comfy-table = { version = "7.1", default-features = false } +hashbrown = { workspace = true } +num-traits = "0.2" +once_cell = { version = "1.19", features = ["parking_lot"] } +regex = "1.10.2" +snafu = "0.8" +uuid = "1" +workspace-hack = { version = "0.1", path = "../workspace-hack" } + +[dev-dependencies] +datafusion = { workspace = true } +proptest = { version = "1.4.0", default-features = false, features = ["std"] } +rand = "0.8.3" diff --git a/arrow_util/src/bitset.rs b/arrow_util/src/bitset.rs new file mode 100644 index 0000000..7fecee6 --- /dev/null +++ b/arrow_util/src/bitset.rs @@ -0,0 +1,879 @@ +use arrow::buffer::{BooleanBuffer, Buffer}; +use std::ops::Range; + +/// An arrow-compatible mutable bitset implementation +/// 
+/// Note: This currently operates on individual bytes at a time +/// it could be optimised to instead operate on usize blocks +#[derive(Debug, Default, Clone)] +pub struct BitSet { + /// The underlying data + /// + /// Data is stored in the least significant bit of a byte first + buffer: Vec, + + /// The length of this mask in bits + len: usize, +} + +impl BitSet { + /// Creates a new BitSet + pub fn new() -> Self { + Self::default() + } + + /// Construct an empty [`BitSet`] with a pre-allocated capacity for `n` + /// bits. + pub fn with_capacity(n: usize) -> Self { + Self { + buffer: Vec::with_capacity((n + 7) / 8), + len: 0, + } + } + + /// Creates a new BitSet with `count` unset bits. + pub fn with_size(count: usize) -> Self { + let mut bitset = Self::default(); + bitset.append_unset(count); + bitset + } + + /// Reserve space for `count` further bits + pub fn reserve(&mut self, count: usize) { + let new_buf_len = (self.len + count + 7) / 8; + self.buffer.reserve(new_buf_len); + } + + /// Appends `count` unset bits + pub fn append_unset(&mut self, count: usize) { + self.len += count; + let new_buf_len = (self.len + 7) / 8; + self.buffer.resize(new_buf_len, 0); + } + + /// Appends `count` set bits + pub fn append_set(&mut self, count: usize) { + let new_len = self.len + count; + let new_buf_len = (new_len + 7) / 8; + + let skew = self.len % 8; + if skew != 0 { + *self.buffer.last_mut().unwrap() |= 0xFF << skew; + } + + self.buffer.resize(new_buf_len, 0xFF); + + let rem = new_len % 8; + if rem != 0 { + *self.buffer.last_mut().unwrap() &= (1 << rem) - 1; + } + + self.len = new_len; + } + + /// Truncates the bitset to the provided length + pub fn truncate(&mut self, len: usize) { + let new_buf_len = (len + 7) / 8; + self.buffer.truncate(new_buf_len); + let overrun = len % 8; + if overrun > 0 { + *self.buffer.last_mut().unwrap() &= (1 << overrun) - 1; + } + self.len = len; + } + + /// Split this bitmap at the specified bit boundary, such that after this + /// call, 
`self` contains the range `[0, n)` and the returned value contains + /// `[n, len)`. + pub fn split_off(&mut self, n: usize) -> Self { + let mut right = Self::with_capacity(self.len - n); + right.extend_from_range(self, n..self.len); + + self.truncate(n); + + right + } + + /// Extends this [`BitSet`] by the context of `other` + pub fn extend_from(&mut self, other: &BitSet) { + self.append_bits(other.len, &other.buffer) + } + + /// Extends this [`BitSet`] by `range` elements in `other` + pub fn extend_from_range(&mut self, other: &BitSet, range: Range) { + let count = range.end - range.start; + if count == 0 { + return; + } + + let start_byte = range.start / 8; + let end_byte = (range.end + 7) / 8; + let skew = range.start % 8; + + // `append_bits` requires the provided `to_set` to be byte aligned, therefore + // if the range being copied is not byte aligned we must first append + // the leading bits to reach a byte boundary + if skew == 0 { + // No skew can simply append bytes directly + self.append_bits(count, &other.buffer[start_byte..end_byte]) + } else if start_byte + 1 == end_byte { + // Append bits from single byte + self.append_bits(count, &[other.buffer[start_byte] >> skew]) + } else { + // Append trailing bits from first byte to reach byte boundary, then append + // bits from the remaining byte-aligned mask + let offset = 8 - skew; + self.append_bits(offset, &[other.buffer[start_byte] >> skew]); + self.append_bits(count - offset, &other.buffer[(start_byte + 1)..end_byte]); + } + } + + /// Appends `count` boolean values from the slice of packed bits + pub fn append_bits(&mut self, count: usize, to_set: &[u8]) { + assert_eq!((count + 7) / 8, to_set.len()); + + let new_len = self.len + count; + let new_buf_len = (new_len + 7) / 8; + self.buffer.reserve(new_buf_len - self.buffer.len()); + + let whole_bytes = count / 8; + let overrun = count % 8; + + let skew = self.len % 8; + if skew == 0 { + self.buffer.extend_from_slice(&to_set[..whole_bytes]); + if overrun 
> 0 { + let masked = to_set[whole_bytes] & ((1 << overrun) - 1); + self.buffer.push(masked) + } + + self.len = new_len; + debug_assert_eq!(self.buffer.len(), new_buf_len); + return; + } + + for to_set_byte in &to_set[..whole_bytes] { + let low = *to_set_byte << skew; + let high = *to_set_byte >> (8 - skew); + + *self.buffer.last_mut().unwrap() |= low; + self.buffer.push(high); + } + + if overrun > 0 { + let masked = to_set[whole_bytes] & ((1 << overrun) - 1); + let low = masked << skew; + *self.buffer.last_mut().unwrap() |= low; + + if overrun > 8 - skew { + let high = masked >> (8 - skew); + self.buffer.push(high) + } + } + + self.len = new_len; + debug_assert_eq!(self.buffer.len(), new_buf_len); + } + + /// Sets a given bit + pub fn set(&mut self, idx: usize) { + assert!(idx <= self.len); + + let byte_idx = idx / 8; + let bit_idx = idx % 8; + self.buffer[byte_idx] |= 1 << bit_idx; + } + + /// Returns if the given index is set + pub fn get(&self, idx: usize) -> bool { + assert!(idx <= self.len); + + let byte_idx = idx / 8; + let bit_idx = idx % 8; + (self.buffer[byte_idx] >> bit_idx) & 1 != 0 + } + + /// Converts this BitSet to a buffer compatible with arrows boolean encoding + pub fn to_arrow(&self) -> BooleanBuffer { + let offset = 0; + BooleanBuffer::new(Buffer::from(&self.buffer), offset, self.len) + } + + /// Returns the number of values stored in the bitset + pub fn len(&self) -> usize { + self.len + } + + /// Returns if this bitset is empty + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns the number of bytes used by this bitset + pub fn byte_len(&self) -> usize { + self.buffer.len() + } + + /// Return the raw packed bytes used by this bitset + pub fn bytes(&self) -> &[u8] { + &self.buffer + } + + /// Return `true` if all bits in the [`BitSet`] are currently set. + pub fn is_all_set(&self) -> bool { + // An empty bitmap has no set bits. 
+ if self.len == 0 { + return false; + } + + // Check all the bytes in the bitmap that have all their bits considered + // part of the bit set. + let full_blocks = (self.len / 8).saturating_sub(1); + if !self.buffer.iter().take(full_blocks).all(|&v| v == u8::MAX) { + return false; + } + + // Check the last byte of the bitmap that may only be partially part of + // the bit set, and therefore need masking to check only the relevant + // bits. + let mask = match self.len % 8 { + 1..=8 => !(0xFF << (self.len % 8)), // LSB mask + 0 => 0xFF, + _ => unreachable!(), + }; + *self.buffer.last().unwrap() == mask + } + + /// Return `true` if all bits in the [`BitSet`] are currently unset. + pub fn is_all_unset(&self) -> bool { + self.buffer.iter().all(|&v| v == 0) + } + + /// Returns the number of set bits in this bitmap. + pub fn count_ones(&self) -> usize { + // Invariant: the bits outside of [0, self.len) are always 0 + self.buffer.iter().map(|v| v.count_ones() as usize).sum() + } + + /// Returns the number of unset bits in this bitmap. + pub fn count_zeros(&self) -> usize { + self.len() - self.count_ones() + } + + /// Returns true if any bit is set (short circuiting). + pub fn is_any_set(&self) -> bool { + self.buffer.iter().any(|&v| v != 0) + } + + /// Returns a value [`Iterator`] that yields boolean values encoded in the + /// bitmap. + pub fn iter(&self) -> Iter<'_> { + Iter::new(self) + } + + /// Returns the bitwise AND between the two [`BitSet`] instances. + /// + /// # Panics + /// + /// Panics if the two sets have differing lengths. + pub fn and(&self, other: &Self) -> Self { + assert_eq!(self.len, other.len); + + Self { + buffer: self + .buffer + .iter() + .zip(other.buffer.iter()) + .map(|(a, b)| a & b) + .collect(), + len: self.len, + } + } +} + +/// A value iterator yielding the boolean values encoded in the bitmap. +#[derive(Debug)] +pub struct Iter<'a> { + /// A reference to the bitmap buffer. 
+ buffer: &'a [u8], + /// The index of the next yielded bit in `buffer`. + idx: usize, + /// The number of bits stored in buffer. + len: usize, +} + +impl<'a> Iter<'a> { + fn new(b: &'a BitSet) -> Self { + Self { + buffer: &b.buffer, + idx: 0, + len: b.len(), + } + } +} + +impl<'a> Iterator for Iter<'a> { + type Item = bool; + + fn next(&mut self) -> Option { + if self.idx >= self.len { + return None; + } + + let byte_idx = self.idx / 8; + let shift = self.idx % 8; + + self.idx += 1; + + let byte = self.buffer[byte_idx]; + let byte = byte >> shift; + + Some(byte & 1 == 1) + } + + fn size_hint(&self) -> (usize, Option) { + let v = self.len - self.idx; + (v, Some(v)) + } +} + +impl<'a> ExactSizeIterator for Iter<'a> {} + +/// Returns an iterator over set bit positions in increasing order +pub fn iter_set_positions(bytes: &[u8]) -> impl Iterator + '_ { + iter_set_positions_with_offset(bytes, 0) +} + +/// Returns an iterator over set bit positions in increasing order starting +/// at the provided bit offset +pub fn iter_set_positions_with_offset( + bytes: &[u8], + offset: usize, +) -> impl Iterator + '_ { + let mut byte_idx = offset / 8; + let mut in_progress = bytes.get(byte_idx).cloned().unwrap_or(0); + + let skew = offset % 8; + in_progress &= 0xFF << skew; + + std::iter::from_fn(move || loop { + if in_progress != 0 { + let bit_pos = in_progress.trailing_zeros(); + in_progress ^= 1 << bit_pos; + return Some((byte_idx * 8) + (bit_pos as usize)); + } + byte_idx += 1; + in_progress = *bytes.get(byte_idx)?; + }) +} + +#[cfg(test)] +mod tests { + use arrow::array::BooleanBufferBuilder; + use proptest::prelude::*; + use rand::prelude::*; + use rand::rngs::OsRng; + + use super::*; + + /// Computes a compacted representation of a given bool array + fn compact_bools(bools: &[bool]) -> Vec { + bools + .chunks(8) + .map(|x| { + let mut collect = 0_u8; + for (idx, set) in x.iter().enumerate() { + if *set { + collect |= 1 << idx + } + } + collect + }) + .collect() + } + + fn 
iter_set_bools(bools: &[bool]) -> impl Iterator + '_ { + bools + .iter() + .enumerate() + .filter(|&(_x, y)| *y) + .map(|(x, _y)| x) + } + + #[test] + fn test_compact_bools() { + let bools = &[ + false, false, true, true, false, false, true, false, true, false, + ]; + let collected = compact_bools(bools); + let indexes: Vec<_> = iter_set_bools(bools).collect(); + assert_eq!(collected.as_slice(), &[0b01001100, 0b00000001]); + assert_eq!(indexes.as_slice(), &[2, 3, 6, 8]) + } + + #[test] + fn test_bit_mask() { + let mut mask = BitSet::new(); + + assert!(!mask.is_any_set()); + + mask.append_bits(8, &[0b11111111]); + let d1 = mask.buffer.clone(); + assert!(mask.is_any_set()); + + mask.append_bits(3, &[0b01010010]); + let d2 = mask.buffer.clone(); + + mask.append_bits(5, &[0b00010100]); + let d3 = mask.buffer.clone(); + + mask.append_bits(2, &[0b11110010]); + let d4 = mask.buffer.clone(); + + mask.append_bits(15, &[0b11011010, 0b01010101]); + let d5 = mask.buffer.clone(); + + assert_eq!(d1.as_slice(), &[0b11111111]); + assert_eq!(d2.as_slice(), &[0b11111111, 0b00000010]); + assert_eq!(d3.as_slice(), &[0b11111111, 0b10100010]); + assert_eq!(d4.as_slice(), &[0b11111111, 0b10100010, 0b00000010]); + assert_eq!( + d5.as_slice(), + &[0b11111111, 0b10100010, 0b01101010, 0b01010111, 0b00000001] + ); + + assert!(mask.get(0)); + assert!(!mask.get(8)); + assert!(mask.get(9)); + assert!(mask.get(19)); + } + + fn make_rng() -> StdRng { + let seed = OsRng.next_u64(); + println!("Seed: {seed}"); + StdRng::seed_from_u64(seed) + } + + #[test] + fn test_bit_mask_all_set() { + let mut mask = BitSet::new(); + let mut all_bools = vec![]; + let mut rng = make_rng(); + + for _ in 0..100 { + let mask_length = (rng.next_u32() % 50) as usize; + let bools: Vec<_> = std::iter::repeat(true).take(mask_length).collect(); + + let collected = compact_bools(&bools); + mask.append_bits(mask_length, &collected); + all_bools.extend_from_slice(&bools); + } + + let collected = compact_bools(&all_bools); + 
assert_eq!(mask.buffer, collected); + + let expected_indexes: Vec<_> = iter_set_bools(&all_bools).collect(); + let actual_indexes: Vec<_> = iter_set_positions(&mask.buffer).collect(); + assert_eq!(expected_indexes, actual_indexes); + } + + #[test] + fn test_bit_mask_fuzz() { + let mut mask = BitSet::new(); + let mut all_bools = vec![]; + let mut rng = make_rng(); + + for _ in 0..100 { + let mask_length = (rng.next_u32() % 50) as usize; + let bools: Vec<_> = std::iter::from_fn(|| Some(rng.next_u32() & 1 == 0)) + .take(mask_length) + .collect(); + + let collected = compact_bools(&bools); + mask.append_bits(mask_length, &collected); + all_bools.extend_from_slice(&bools); + } + + let collected = compact_bools(&all_bools); + assert_eq!(mask.buffer, collected); + + let expected_indexes: Vec<_> = iter_set_bools(&all_bools).collect(); + let actual_indexes: Vec<_> = iter_set_positions(&mask.buffer).collect(); + assert_eq!(expected_indexes, actual_indexes); + + if !all_bools.is_empty() { + for _ in 0..10 { + let offset = rng.next_u32() as usize % all_bools.len(); + + let expected_indexes: Vec<_> = iter_set_bools(&all_bools[offset..]) + .map(|x| x + offset) + .collect(); + + let actual_indexes: Vec<_> = + iter_set_positions_with_offset(&mask.buffer, offset).collect(); + + assert_eq!(expected_indexes, actual_indexes); + } + } + + for index in actual_indexes { + assert!(mask.get(index)); + } + } + + #[test] + fn test_append_fuzz() { + let mut mask = BitSet::new(); + let mut all_bools = vec![]; + let mut rng = make_rng(); + + for _ in 0..100 { + let len = (rng.next_u32() % 32) as usize; + let set = rng.next_u32() & 1 == 0; + + match set { + true => mask.append_set(len), + false => mask.append_unset(len), + } + + all_bools.extend(std::iter::repeat(set).take(len)); + + let collected = compact_bools(&all_bools); + assert_eq!(mask.buffer, collected); + } + } + + #[test] + fn test_truncate_fuzz() { + let mut mask = BitSet::new(); + let mut all_bools = vec![]; + let mut rng = 
make_rng(); + + for _ in 0..100 { + let mask_length = (rng.next_u32() % 32) as usize; + let bools: Vec<_> = std::iter::from_fn(|| Some(rng.next_u32() & 1 == 0)) + .take(mask_length) + .collect(); + + let collected = compact_bools(&bools); + mask.append_bits(mask_length, &collected); + all_bools.extend_from_slice(&bools); + + if !all_bools.is_empty() { + let truncate = rng.next_u32() as usize % all_bools.len(); + mask.truncate(truncate); + all_bools.truncate(truncate); + } + + let collected = compact_bools(&all_bools); + assert_eq!(mask.buffer, collected); + } + } + + #[test] + fn test_extend_range_fuzz() { + let mut rng = make_rng(); + let src_len = 32; + let src_bools: Vec<_> = std::iter::from_fn(|| Some(rng.next_u32() & 1 == 0)) + .take(src_len) + .collect(); + + let mut src_mask = BitSet::new(); + src_mask.append_bits(src_len, &compact_bools(&src_bools)); + + let mut dst_bools = Vec::new(); + let mut dst_mask = BitSet::new(); + + for _ in 0..100 { + let a = rng.next_u32() as usize % src_len; + let b = rng.next_u32() as usize % src_len; + + let start = a.min(b); + let end = a.max(b); + + dst_bools.extend_from_slice(&src_bools[start..end]); + dst_mask.extend_from_range(&src_mask, start..end); + + let collected = compact_bools(&dst_bools); + assert_eq!(dst_mask.buffer, collected); + } + } + + #[test] + fn test_arrow_compat() { + let bools = &[ + false, false, true, true, false, false, true, false, true, false, false, true, + ]; + + let mut builder = BooleanBufferBuilder::new(bools.len()); + builder.append_slice(bools); + let buffer = builder.finish(); + + let collected = compact_bools(bools); + let mut mask = BitSet::new(); + mask.append_bits(bools.len(), &collected); + let mask_buffer = mask.to_arrow(); + + assert_eq!(collected.as_slice(), buffer.values()); + assert_eq!(buffer.values(), mask_buffer.into_inner().as_slice()); + } + + #[test] + #[should_panic = "idx <= self.len"] + fn test_bitset_set_get_out_of_bounds() { + let mut v = BitSet::with_size(4); + + // 
The bitset is of length 4, which is backed by a single byte with 8 + // bits of storage capacity. + // + // Accessing bits past the 4 the bitset "contains" should not succeed. + + v.get(5); + v.set(5); + } + + #[test] + fn test_all_set_unset() { + for i in 1..100 { + let mut v = BitSet::new(); + assert!(!v.is_any_set()); + v.append_set(i); + assert!(v.is_all_set()); + assert!(!v.is_all_unset()); + assert!(v.is_any_set()); + + let mut v = BitSet::new(); + v.append_unset(i); + assert!(!v.is_any_set()); + v.append_set(1); + assert!(v.is_any_set()); + } + } + + #[test] + fn test_all_set_unset_multi_byte() { + let mut v = BitSet::new(); + + // Bitmap is composed of entirely set bits. + v.append_set(100); + assert!(v.is_all_set()); + assert!(!v.is_all_unset()); + + // Now the bitmap is neither composed of entirely set, nor entirely + // unset bits. + v.append_unset(1); + assert!(!v.is_all_set()); + assert!(!v.is_all_unset()); + + let mut v = BitSet::new(); + + // Bitmap is composed of entirely unset bits. + v.append_unset(100); + assert!(!v.is_all_set()); + assert!(v.is_all_unset()); + + // And once again, it is neither all set, nor all unset. + v.append_set(1); + assert!(!v.is_all_set()); + assert!(!v.is_all_unset()); + } + + #[test] + fn test_all_set_unset_single_byte() { + let mut v = BitSet::new(); + + // Bitmap is composed of entirely set bits. + v.append_set(2); + assert!(v.is_all_set()); + assert!(!v.is_all_unset()); + + // Now the bitmap is neither composed of entirely set, nor entirely + // unset bits. + v.append_unset(1); + assert!(!v.is_all_set()); + assert!(!v.is_all_unset()); + + let mut v = BitSet::new(); + + // Bitmap is composed of entirely unset bits. + v.append_unset(2); + assert!(!v.is_all_set()); + assert!(v.is_all_unset()); + + // And once again, it is neither all set, nor all unset. 
+ v.append_set(1); + assert!(!v.is_all_set()); + assert!(!v.is_all_unset()); + } + + #[test] + fn test_all_set_unset_empty() { + let v = BitSet::new(); + assert!(!v.is_all_set()); + assert!(v.is_all_unset()); + } + + #[test] + fn test_split_byte_boundary() { + let mut a = BitSet::new(); + + a.append_set(16); + a.append_unset(8); + a.append_set(8); + + let b = a.split_off(16); + + assert_eq!(a.len(), 16); + assert_eq!(b.len(), 16); + + // All the bits in A are set. + assert!(a.is_all_set()); + for i in 0..16 { + assert!(a.get(i)); + } + + // The first 8 bits in b are unset, and the next 8 bits are set. + for i in 0..8 { + assert!(!b.get(i)); + } + for i in 8..16 { + assert!(b.get(i)); + } + } + + #[test] + fn test_split_sub_byte_boundary() { + let mut a = BitSet::new(); + + a.append_set(3); + a.append_unset(3); + a.append_set(1); + + assert_eq!(a.bytes(), &[0b01000111]); + + let b = a.split_off(5); + + assert_eq!(a.len(), 5); + assert_eq!(b.len(), 2); + + // A contains 3 set bits & 2 unset bits, with the rest masked out. + assert_eq!(a.bytes(), &[0b00000111]); + + // B contains 1 unset bit, and then 1 set bit + assert_eq!(b.bytes(), &[0b0000010]); + } + + #[test] + fn test_split_multi_byte_unclean_boundary() { + let mut a = BitSet::new(); + + a.append_set(8); + a.append_unset(1); + a.append_set(1); + a.append_unset(1); + a.append_set(1); + + assert_eq!(a.bytes(), &[0b11111111, 0b00001010]); + + let b = a.split_off(10); + + assert_eq!(a.len(), 10); + assert_eq!(b.len(), 2); + + assert_eq!(a.bytes(), &[0b11111111, 0b00000010]); + assert_eq!(b.bytes(), &[0b0000010]); + } + + #[test] + fn test_count_ones_with_truncate() { + // For varying sizes of bitmaps. + for i in 1..150 { + let mut b = BitSet::new(); + + // Set "i" number of bits in 2*i values. + for _ in 0..i { + b.append_unset(1); + b.append_set(1); + } + + assert_eq!(b.len(), 2 * i); + assert_eq!(b.count_ones(), i); + assert_eq!(b.count_zeros(), i); + + // Split it such that the last bit is removed. 
+ let other = b.split_off((2 * i) - 1); + assert_eq!(other.len(), 1); + assert_eq!(other.count_ones(), 1); + assert_eq!(other.count_zeros(), 0); + + // Which means the original bitmap must now have 1 less 1 bit. + assert_eq!(b.len(), (2 * i) - 1); + assert_eq!(b.count_ones(), i - 1); + assert_eq!(b.count_zeros(), i); + } + } + + prop_compose! { + /// Returns a [`BitSet`] of random length and content. + fn arbitrary_bitset()( + values in prop::collection::vec(any::(), 0..20) + ) -> BitSet { + let mut b = BitSet::new(); + + for v in &values { + match v { + true => b.append_set(1), + false => b.append_unset(1), + } + } + + b + } + } + + proptest! { + #[test] + fn prop_iter( + values in prop::collection::vec(any::(), 0..20), + ) { + let mut b = BitSet::new(); + + for v in &values { + match v { + true => b.append_set(1), + false => b.append_unset(1), + } + } + + assert_eq!(values.len(), b.len()); + + let got = b.iter().collect::>(); + assert_eq!(values, got); + + // Exact size iter + assert_eq!(b.iter().len(), values.len()); + } + + #[test] + fn prop_and( + mut a in arbitrary_bitset(), + mut b in arbitrary_bitset(), + ) { + let min_len = a.len().min(b.len()); + // Truncate a and b to the same length. + a.truncate(min_len); + b.truncate(min_len); + + let want = a + .iter() + .zip(b.iter()) + .map(|(a, b)| a & b) + .collect::>(); + + let c = a.and(&b); + let got = c.iter().collect::>(); + + assert_eq!(got, want); + } + } +} diff --git a/arrow_util/src/dictionary.rs b/arrow_util/src/dictionary.rs new file mode 100644 index 0000000..1885deb --- /dev/null +++ b/arrow_util/src/dictionary.rs @@ -0,0 +1,299 @@ +//! Contains a structure to map from strings to integer symbols based on +//! string interning. 
+use std::convert::TryFrom; + +use arrow::array::{Array, ArrayDataBuilder, DictionaryArray}; +use arrow::buffer::NullBuffer; +use arrow::datatypes::{DataType, Int32Type}; +use hashbrown::HashMap; +use num_traits::{AsPrimitive, FromPrimitive, Zero}; +use snafu::Snafu; + +use crate::string::PackedStringArray; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("duplicate key found {}", key))] + DuplicateKeyFound { key: String }, +} + +/// A String dictionary that builds on top of `PackedStringArray` adding O(1) +/// index lookups for a given string +/// +/// Heavily inspired by the string-interner crate +#[derive(Debug, Clone)] +pub struct StringDictionary { + hash: ahash::RandomState, + /// Used to provide a lookup from string value to key type + /// + /// Note: K's hash implementation is not used, instead the raw entry + /// API is used to store keys w.r.t the hash of the strings themselves + /// + dedup: HashMap, + /// Used to store strings + storage: PackedStringArray, +} + +impl + FromPrimitive + Zero> Default for StringDictionary { + fn default() -> Self { + Self { + hash: ahash::RandomState::new(), + dedup: Default::default(), + storage: PackedStringArray::new(), + } + } +} + +impl + FromPrimitive + Zero> StringDictionary { + pub fn new() -> Self { + Default::default() + } + + pub fn with_capacity(keys: usize, values: usize) -> StringDictionary { + Self { + hash: Default::default(), + dedup: HashMap::with_capacity_and_hasher(keys, ()), + storage: PackedStringArray::with_capacity(keys, values), + } + } + + /// Returns the id corresponding to value, adding an entry for the + /// id if it is not yet present in the dictionary. 
+ pub fn lookup_value_or_insert(&mut self, value: &str) -> K { + use hashbrown::hash_map::RawEntryMut; + + let hasher = &self.hash; + let storage = &mut self.storage; + let hash = hash_str(hasher, value); + + let entry = self + .dedup + .raw_entry_mut() + .from_hash(hash, |key| value == storage.get(key.as_()).unwrap()); + + match entry { + RawEntryMut::Occupied(entry) => *entry.into_key(), + RawEntryMut::Vacant(entry) => { + let index = storage.append(value); + let key = + K::from_usize(index).expect("failed to fit string index into dictionary key"); + *entry + .insert_with_hasher(hash, key, (), |key| { + let string = storage.get(key.as_()).unwrap(); + hash_str(hasher, string) + }) + .0 + } + } + } + + /// Returns the ID in self.dictionary that corresponds to `value`, if any. + pub fn lookup_value(&self, value: &str) -> Option { + let hash = hash_str(&self.hash, value); + self.dedup + .raw_entry() + .from_hash(hash, |key| value == self.storage.get(key.as_()).unwrap()) + .map(|(&symbol, &())| symbol) + } + + /// Returns the str in self.dictionary that corresponds to `id` + pub fn lookup_id(&self, id: K) -> Option<&str> { + self.storage.get(id.as_()) + } + + pub fn size(&self) -> usize { + self.storage.size() + self.dedup.len() * std::mem::size_of::() + } + + pub fn values(&self) -> &PackedStringArray { + &self.storage + } + + pub fn into_inner(self) -> PackedStringArray { + self.storage + } + + /// Truncates this dictionary removing all keys larger than `id` + pub fn truncate(&mut self, id: K) { + let id = id.as_(); + self.dedup.retain(|k, _| k.as_() <= id); + self.storage.truncate(id + 1) + } + + /// Clears this dictionary removing all elements + pub fn clear(&mut self) { + self.storage.clear(); + self.dedup.clear() + } +} + +fn hash_str(hasher: &ahash::RandomState, value: &str) -> u64 { + hasher.hash_one(value) +} + +impl StringDictionary { + /// Convert to an arrow representation with the provided set of + /// keys and an optional null bitmask + pub fn 
to_arrow(&self, keys: I, nulls: Option) -> DictionaryArray + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + // the nulls are recorded in the keys array, the dictionary itself + // is entirely non null + let dictionary_nulls = None; + let keys = keys.into_iter(); + + let array_data = ArrayDataBuilder::new(DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + )) + .len(keys.len()) + .add_buffer(keys.collect()) + .add_child_data(self.storage.to_arrow(dictionary_nulls).into_data()) + .nulls(nulls) + // TODO consider skipping the validation checks by using + // `build_unchecked()` + .build() + .expect("Valid array data"); + + DictionaryArray::::from(array_data) + } +} + +impl TryFrom> for StringDictionary +where + K: AsPrimitive + FromPrimitive + Zero, +{ + type Error = Error; + + fn try_from(storage: PackedStringArray) -> Result { + use hashbrown::hash_map::RawEntryMut; + + let hasher = ahash::RandomState::new(); + let mut dedup: HashMap = HashMap::with_capacity_and_hasher(storage.len(), ()); + for (idx, value) in storage.iter().enumerate() { + let hash = hash_str(&hasher, value); + + let entry = dedup + .raw_entry_mut() + .from_hash(hash, |key| value == storage.get(key.as_()).unwrap()); + + match entry { + RawEntryMut::Occupied(_) => { + return Err(Error::DuplicateKeyFound { + key: value.to_string(), + }) + } + RawEntryMut::Vacant(entry) => { + let key = + K::from_usize(idx).expect("failed to fit string index into dictionary key"); + + entry.insert_with_hasher(hash, key, (), |key| { + let string = storage.get(key.as_()).unwrap(); + hash_str(&hasher, string) + }); + } + } + } + + Ok(Self { + hash: hasher, + dedup, + storage, + }) + } +} + +#[cfg(test)] +mod test { + use std::convert::TryInto; + + use super::*; + + #[test] + fn test_dictionary() { + let mut dictionary = StringDictionary::::new(); + + let id1 = dictionary.lookup_value_or_insert("cupcake"); + let id2 = dictionary.lookup_value_or_insert("cupcake"); + let id3 = 
dictionary.lookup_value_or_insert("womble"); + + let id4 = dictionary.lookup_value("cupcake").unwrap(); + let id5 = dictionary.lookup_value("womble").unwrap(); + + let cupcake = dictionary.lookup_id(id4).unwrap(); + let womble = dictionary.lookup_id(id5).unwrap(); + + let arrow_expected = arrow::array::StringArray::from(vec!["cupcake", "womble"]); + let arrow_actual = dictionary.values().to_arrow(None); + + assert_eq!(id1, id2); + assert_eq!(id1, id4); + assert_ne!(id1, id3); + assert_eq!(id3, id5); + + assert_eq!(cupcake, "cupcake"); + assert_eq!(womble, "womble"); + + assert!(dictionary.lookup_value("foo").is_none()); + assert!(dictionary.lookup_id(-1).is_none()); + assert_eq!(arrow_expected, arrow_actual); + } + + #[test] + fn from_string_array() { + let mut data = PackedStringArray::::new(); + data.append("cupcakes"); + data.append("foo"); + data.append("bingo"); + + let dictionary: StringDictionary<_> = data.try_into().unwrap(); + + assert_eq!(dictionary.lookup_value("cupcakes"), Some(0)); + assert_eq!(dictionary.lookup_value("foo"), Some(1)); + assert_eq!(dictionary.lookup_value("bingo"), Some(2)); + + assert_eq!(dictionary.lookup_id(0), Some("cupcakes")); + assert_eq!(dictionary.lookup_id(1), Some("foo")); + assert_eq!(dictionary.lookup_id(2), Some("bingo")); + } + + #[test] + fn from_string_array_duplicates() { + let mut data = PackedStringArray::::new(); + data.append("cupcakes"); + data.append("foo"); + data.append("bingo"); + data.append("cupcakes"); + + let err = TryInto::>::try_into(data).expect_err("expected failure"); + assert!(matches!(err, Error::DuplicateKeyFound { key } if &key == "cupcakes")) + } + + #[test] + fn test_truncate() { + let mut dictionary = StringDictionary::::new(); + dictionary.lookup_value_or_insert("cupcake"); + dictionary.lookup_value_or_insert("cupcake"); + dictionary.lookup_value_or_insert("bingo"); + let bingo = dictionary.lookup_value_or_insert("bingo"); + let bongo = dictionary.lookup_value_or_insert("bongo"); + 
dictionary.lookup_value_or_insert("bingo"); + dictionary.lookup_value_or_insert("cupcake"); + + dictionary.truncate(bingo); + + assert_eq!(dictionary.values().len(), 2); + assert_eq!(dictionary.dedup.len(), 2); + + assert_eq!(dictionary.lookup_value("cupcake"), Some(0)); + assert_eq!(dictionary.lookup_value("bingo"), Some(1)); + + assert!(dictionary.lookup_value("bongo").is_none()); + assert!(dictionary.lookup_id(bongo).is_none()); + + dictionary.lookup_value_or_insert("bongo"); + assert_eq!(dictionary.lookup_value("bongo"), Some(2)); + } +} diff --git a/arrow_util/src/display.rs b/arrow_util/src/display.rs new file mode 100644 index 0000000..cba4b91 --- /dev/null +++ b/arrow_util/src/display.rs @@ -0,0 +1,206 @@ +use arrow::array::{ArrayRef, DurationNanosecondArray, TimestampNanosecondArray}; +use arrow::datatypes::{DataType, TimeUnit}; +use arrow::error::{ArrowError, Result}; +use arrow::record_batch::RecordBatch; + +use comfy_table::{Cell, Table}; + +use chrono::prelude::*; + +/// custom version of +/// [pretty_format_batches](arrow::util::pretty::pretty_format_batches) +/// that displays timestamps using RFC3339 format (e.g. 
`2021-07-20T23:28:50Z`) +/// +/// Should be removed if/when the capability is added upstream to arrow: +/// +pub fn pretty_format_batches(results: &[RecordBatch]) -> Result { + Ok(create_table(results)?.to_string()) +} + +/// Convert the value at `column[row]` to a String +/// +/// Special cases printing Timestamps in RFC3339 for IOx, otherwise +/// falls back to Arrow's implementation +/// +fn array_value_to_string(column: &ArrayRef, row: usize) -> Result { + match column.data_type() { + DataType::Timestamp(TimeUnit::Nanosecond, None) if column.is_valid(row) => { + let ts_column = column + .as_any() + .downcast_ref::() + .unwrap(); + + let ts_value = ts_column.value(row); + const NANOS_IN_SEC: i64 = 1_000_000_000; + let secs = ts_value / NANOS_IN_SEC; + let nanos = (ts_value - (secs * NANOS_IN_SEC)) as u32; + let ts = NaiveDateTime::from_timestamp_opt(secs, nanos).ok_or_else(|| { + ArrowError::ExternalError( + format!("Cannot process timestamp (secs={secs}, nanos={nanos})").into(), + ) + })?; + // treat as UTC + let ts = DateTime::::from_naive_utc_and_offset(ts, Utc); + // convert to string in preferred influx format + let use_z = true; + Ok(ts.to_rfc3339_opts(SecondsFormat::AutoSi, use_z)) + } + // TODO(edd): see https://github.com/apache/arrow-rs/issues/1168 + DataType::Duration(TimeUnit::Nanosecond) if column.is_valid(row) => { + let dur_column = column + .as_any() + .downcast_ref::() + .unwrap(); + + let duration = std::time::Duration::from_nanos( + dur_column + .value(row) + .try_into() + .map_err(|e| ArrowError::InvalidArgumentError(format!("{e:?}")))?, + ); + Ok(format!("{duration:?}")) + } + _ => { + // fallback to arrow's default printing for other types + arrow::util::display::array_value_to_string(column, row) + } + } +} + +/// Convert a series of record batches into a table +/// +/// NB: COPIED FROM ARROW +fn create_table(results: &[RecordBatch]) -> Result { + let mut table = Table::new(); + table.load_preset("||--+-++| ++++++"); + + if 
results.is_empty() { + return Ok(table); + } + + let schema = results[0].schema(); + + let mut header = Vec::new(); + for field in schema.fields() { + header.push(Cell::new(field.name())); + } + table.set_header(header); + + for (i, batch) in results.iter().enumerate() { + if batch.schema() != schema { + return Err(ArrowError::SchemaError(format!( + "Batches have different schemas:\n\nFirst:\n{}\n\nBatch {}:\n{}", + schema, + i + 1, + batch.schema() + ))); + } + + for row in 0..batch.num_rows() { + let mut cells = Vec::new(); + for col in 0..batch.num_columns() { + let column = batch.column(col); + cells.push(Cell::new(array_value_to_string(column, row)?)); + } + table.add_row(cells); + } + } + + Ok(table) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + use arrow::{ + array::{ + ArrayRef, BooleanArray, DictionaryArray, Float64Array, Int64Array, StringArray, + UInt64Array, + }, + datatypes::Int32Type, + }; + use datafusion::common::assert_contains; + + #[test] + fn test_formatting() { + // tests formatting all of the Arrow array types used in IOx + + // tags use string dictionary + let dict_array: ArrayRef = Arc::new( + vec![Some("a"), None, Some("b")] + .into_iter() + .collect::>(), + ); + + // field types + let int64_array: ArrayRef = + Arc::new([Some(-1), None, Some(2)].iter().collect::()); + let uint64_array: ArrayRef = + Arc::new([Some(1), None, Some(2)].iter().collect::()); + let float64_array: ArrayRef = Arc::new( + [Some(1.0), None, Some(2.0)] + .iter() + .collect::(), + ); + let bool_array: ArrayRef = Arc::new( + [Some(true), None, Some(false)] + .iter() + .collect::(), + ); + let string_array: ArrayRef = Arc::new( + vec![Some("foo"), None, Some("bar")] + .into_iter() + .collect::(), + ); + + // timestamp type + let ts_array: ArrayRef = Arc::new( + [None, Some(100), Some(1626823730000000000)] + .iter() + .collect::(), + ); + + let batch = RecordBatch::try_from_iter(vec![ + ("dict", dict_array), + ("int64", int64_array), + 
("uint64", uint64_array), + ("float64", float64_array), + ("bool", bool_array), + ("string", string_array), + ("time", ts_array), + ]) + .unwrap(); + + let table = pretty_format_batches(&[batch]).unwrap(); + + let expected = vec![ + "+------+-------+--------+---------+-------+--------+--------------------------------+", + "| dict | int64 | uint64 | float64 | bool | string | time |", + "+------+-------+--------+---------+-------+--------+--------------------------------+", + "| a | -1 | 1 | 1.0 | true | foo | |", + "| | | | | | | 1970-01-01T00:00:00.000000100Z |", + "| b | 2 | 2 | 2.0 | false | bar | 2021-07-20T23:28:50Z |", + "+------+-------+--------+---------+-------+--------+--------------------------------+", + ]; + + let actual: Vec<&str> = table.lines().collect(); + assert_eq!( + expected, actual, + "Expected:\n\n{expected:#?}\nActual:\n\n{actual:#?}\n" + ); + } + + #[test] + fn test_pretty_format_batches_checks_schemas() { + let int64_array: ArrayRef = Arc::new([Some(2)].iter().collect::()); + let uint64_array: ArrayRef = Arc::new([Some(2)].iter().collect::()); + + let batch1 = RecordBatch::try_from_iter(vec![("col", int64_array)]).unwrap(); + let batch2 = RecordBatch::try_from_iter(vec![("col", uint64_array)]).unwrap(); + + let err = pretty_format_batches(&[batch1, batch2]).unwrap_err(); + assert_contains!(err.to_string(), "Batches have different schemas:"); + } +} diff --git a/arrow_util/src/flight.rs b/arrow_util/src/flight.rs new file mode 100644 index 0000000..66521aa --- /dev/null +++ b/arrow_util/src/flight.rs @@ -0,0 +1,26 @@ +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; + +/// Prepare an arrow Schema for transport over the Arrow Flight protocol +/// +/// Converts dictionary types to underlying types due to +pub fn prepare_schema_for_flight(schema: SchemaRef) -> SchemaRef { + let fields: Fields = schema + .fields() + .iter() + .map(|field| match field.data_type() { + DataType::Dictionary(_, value_type) => 
Arc::new( + Field::new( + field.name(), + value_type.as_ref().clone(), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()), + ), + _ => Arc::clone(field), + }) + .collect(); + + Arc::new(Schema::new(fields).with_metadata(schema.metadata().clone())) +} diff --git a/arrow_util/src/lib.rs b/arrow_util/src/lib.rs new file mode 100644 index 0000000..613d794 --- /dev/null +++ b/arrow_util/src/lib.rs @@ -0,0 +1,27 @@ +#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)] +#![allow(clippy::clone_on_ref_ptr)] +#![warn( + missing_copy_implementations, + missing_debug_implementations, + clippy::explicit_iter_loop, + // See https://github.com/influxdata/influxdb_iox/pull/1671 + clippy::future_not_send, + clippy::clone_on_ref_ptr, + clippy::todo, + clippy::dbg_macro, + unused_crate_dependencies +)] + +// Workaround for "unused crate" lint false positives. +use workspace_hack as _; + +pub mod bitset; +pub mod dictionary; +pub mod display; +pub mod flight; +pub mod optimize; +pub mod string; +pub mod util; + +/// This has a collection of testing helper functions +pub mod test_util; diff --git a/arrow_util/src/optimize.rs b/arrow_util/src/optimize.rs new file mode 100644 index 0000000..2f7ffbf --- /dev/null +++ b/arrow_util/src/optimize.rs @@ -0,0 +1,299 @@ +use std::collections::BTreeSet; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, DictionaryArray, StringArray}; +use arrow::datatypes::{DataType, Int32Type}; +use arrow::error::{ArrowError, Result}; +use arrow::record_batch::RecordBatch; +use hashbrown::HashMap; + +use crate::dictionary::StringDictionary; + +/// Takes a record batch and returns a new record batch with dictionaries +/// optimized to contain no duplicate or unreferenced values +/// +/// Where the input dictionaries are sorted, the output dictionaries +/// will also be +pub fn optimize_dictionaries(batch: &RecordBatch) -> Result { + let schema = batch.schema(); + let new_columns = batch + .columns() + .iter() 
+ .zip(schema.fields()) + .map(|(col, field)| match field.data_type() { + DataType::Dictionary(key, value) => optimize_dict_col(col, key, value), + _ => Ok(Arc::clone(col)), + }) + .collect::>>()?; + + RecordBatch::try_new(schema, new_columns) +} + +/// Optimizes the dictionaries for a column +fn optimize_dict_col( + col: &ArrayRef, + key_type: &DataType, + value_type: &DataType, +) -> Result { + if key_type != &DataType::Int32 { + return Err(ArrowError::NotYetImplemented(format!( + "truncating non-Int32 dictionaries not supported: {key_type}" + ))); + } + + if value_type != &DataType::Utf8 { + return Err(ArrowError::NotYetImplemented(format!( + "truncating non-string dictionaries not supported: {value_type}" + ))); + } + + let col = col + .as_any() + .downcast_ref::>() + .expect("unexpected datatype"); + + let keys = col.keys(); + let values = col.values(); + let values = values + .as_any() + .downcast_ref::() + .expect("unexpected datatype"); + + // The total length of the resulting values array + let mut values_len = 0_usize; + + // Keys that appear in the values array + // Use a BTreeSet to preserve the order of the dictionary + let mut used_keys = BTreeSet::new(); + for key in keys.iter().flatten() { + if used_keys.insert(key) { + values_len += values.value_length(key as usize) as usize; + } + } + + // Then perform deduplication + let mut new_dictionary = StringDictionary::with_capacity(used_keys.len(), values_len); + let mut old_to_new_idx: HashMap = HashMap::with_capacity(used_keys.len()); + for key in used_keys { + let new_key = new_dictionary.lookup_value_or_insert(values.value(key as usize)); + old_to_new_idx.insert(key, new_key); + } + + let new_keys = keys.iter().map(|x| match x { + Some(x) => *old_to_new_idx.get(&x).expect("no mapping found"), + None => -1, + }); + + let nulls = keys.nulls().cloned(); + Ok(Arc::new(new_dictionary.to_arrow(new_keys, nulls))) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate as arrow_util; + use 
crate::assert_batches_eq; + use arrow::array::{ArrayDataBuilder, DictionaryArray, Float64Array, Int32Array, StringArray}; + use arrow::compute::concat; + use std::iter::FromIterator; + + #[test] + fn test_optimize_dictionaries() { + let values = StringArray::from(vec![ + "duplicate", + "duplicate", + "foo", + "boo", + "unused", + "duplicate", + ]); + let keys = Int32Array::from(vec![ + Some(0), + Some(1), + None, + Some(1), + Some(2), + Some(5), + Some(3), + ]); + + let batch = RecordBatch::try_from_iter(vec![( + "foo", + Arc::new(build_dict(keys, values)) as ArrayRef, + )]) + .unwrap(); + + let optimized = optimize_dictionaries(&batch).unwrap(); + + let col = optimized + .column(0) + .as_any() + .downcast_ref::>() + .unwrap(); + + let values = col.values(); + let values = values.as_any().downcast_ref::().unwrap(); + let values = values.iter().flatten().collect::>(); + assert_eq!(values, vec!["duplicate", "foo", "boo"]); + + assert_batches_eq!( + vec![ + "+-----------+", + "| foo |", + "+-----------+", + "| duplicate |", + "| duplicate |", + "| |", + "| duplicate |", + "| foo |", + "| duplicate |", + "| boo |", + "+-----------+", + ], + &[optimized] + ); + } + + #[test] + fn test_optimize_dictionaries_concat() { + let f1_1 = Float64Array::from(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)]); + let t2_1 = DictionaryArray::::from_iter(vec![ + Some("a"), + Some("g"), + Some("a"), + Some("b"), + ]); + let t1_1 = DictionaryArray::::from_iter(vec![ + Some("a"), + Some("a"), + Some("b"), + Some("b"), + ]); + + let f1_2 = Float64Array::from(vec![Some(1.0), Some(5.0), Some(2.0), Some(46.0)]); + let t2_2 = DictionaryArray::::from_iter(vec![ + Some("a"), + Some("b"), + Some("a"), + Some("a"), + ]); + let t1_2 = DictionaryArray::::from_iter(vec![ + Some("a"), + Some("d"), + Some("a"), + Some("b"), + ]); + + let concat = RecordBatch::try_from_iter(vec![ + ("f1", concat(&[&f1_1, &f1_2]).unwrap()), + ("t2", concat(&[&t2_1, &t2_2]).unwrap()), + ("t1", concat(&[&t1_1, 
&t1_2]).unwrap()), + ]) + .unwrap(); + + let optimized = optimize_dictionaries(&concat).unwrap(); + + let col = optimized + .column(optimized.schema().column_with_name("t2").unwrap().0) + .as_any() + .downcast_ref::>() + .unwrap(); + + let values = col.values(); + let values = values.as_any().downcast_ref::().unwrap(); + let values = values.iter().flatten().collect::>(); + assert_eq!(values, vec!["a", "g", "b"]); + + let col = optimized + .column(optimized.schema().column_with_name("t1").unwrap().0) + .as_any() + .downcast_ref::>() + .unwrap(); + + let values = col.values(); + let values = values.as_any().downcast_ref::().unwrap(); + let values = values.iter().flatten().collect::>(); + assert_eq!(values, vec!["a", "b", "d"]); + + assert_batches_eq!( + vec![ + "+------+----+----+", + "| f1 | t2 | t1 |", + "+------+----+----+", + "| 1.0 | a | a |", + "| 2.0 | g | a |", + "| 3.0 | a | b |", + "| 4.0 | b | b |", + "| 1.0 | a | a |", + "| 5.0 | b | d |", + "| 2.0 | a | a |", + "| 46.0 | a | b |", + "+------+----+----+", + ], + &[optimized] + ); + } + + #[test] + fn test_optimize_dictionaries_null() { + let values = StringArray::from(vec!["bananas"]); + let keys = Int32Array::from(vec![None, None, Some(0)]); + let col = Arc::new(build_dict(keys, values)) as ArrayRef; + + let col = optimize_dict_col(&col, &DataType::Int32, &DataType::Utf8).unwrap(); + + let batch = RecordBatch::try_from_iter(vec![("t", col)]).unwrap(); + + assert_batches_eq!( + vec![ + "+---------+", + "| t |", + "+---------+", + "| |", + "| |", + "| bananas |", + "+---------+", + ], + &[batch] + ); + } + + #[test] + fn test_optimize_dictionaries_slice() { + let values = StringArray::from(vec!["bananas"]); + let keys = Int32Array::from(vec![None, Some(0), None]); + let col = Arc::new(build_dict(keys, values)) as ArrayRef; + let col = col.slice(1, 2); + + let col = optimize_dict_col(&col, &DataType::Int32, &DataType::Utf8).unwrap(); + + let batch = RecordBatch::try_from_iter(vec![("t", col)]).unwrap(); + + 
assert_batches_eq!( + vec![ + "+---------+", + "| t |", + "+---------+", + "| bananas |", + "| |", + "+---------+", + ], + &[batch] + ); + } + + fn build_dict(keys: Int32Array, values: StringArray) -> DictionaryArray { + let data = ArrayDataBuilder::new(DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + )) + .len(keys.len()) + .add_buffer(keys.to_data().buffers()[0].clone()) + .nulls(keys.nulls().cloned()) + .add_child_data(values.into_data()) + .build() + .unwrap(); + + DictionaryArray::from(data) + } +} diff --git a/arrow_util/src/string.rs b/arrow_util/src/string.rs new file mode 100644 index 0000000..5460a38 --- /dev/null +++ b/arrow_util/src/string.rs @@ -0,0 +1,384 @@ +use arrow::array::ArrayDataBuilder; +use arrow::array::StringArray; +use arrow::buffer::Buffer; +use arrow::buffer::NullBuffer; +use num_traits::{AsPrimitive, FromPrimitive, Zero}; +use std::fmt::Debug; +use std::ops::Range; + +/// A packed string array that stores start and end indexes into +/// a contiguous string slice. 
+/// +/// The type parameter K alters the type used to store the offsets +#[derive(Debug, Clone)] +pub struct PackedStringArray { + /// The start and end offsets of strings stored in storage + offsets: Vec, + /// A contiguous array of string data + storage: String, +} + +impl Default for PackedStringArray { + fn default() -> Self { + Self { + offsets: vec![K::zero()], + storage: String::new(), + } + } +} + +impl + FromPrimitive + Zero> PackedStringArray { + pub fn new() -> Self { + Self::default() + } + + pub fn new_empty(len: usize) -> Self { + Self { + offsets: vec![K::zero(); len + 1], + storage: String::new(), + } + } + + pub fn with_capacity(keys: usize, values: usize) -> Self { + let mut offsets = Vec::with_capacity(keys + 1); + offsets.push(K::zero()); + + Self { + offsets, + storage: String::with_capacity(values), + } + } + + /// Append a value + /// + /// Returns the index of the appended data + pub fn append(&mut self, data: &str) -> usize { + let id = self.offsets.len() - 1; + + let offset = self.storage.len() + data.len(); + let offset = K::from_usize(offset).expect("failed to fit into offset type"); + + self.offsets.push(offset); + self.storage.push_str(data); + + id + } + + /// Extends this [`PackedStringArray`] by the contents of `other` + pub fn extend_from(&mut self, other: &PackedStringArray) { + let offset = self.storage.len(); + self.storage.push_str(other.storage.as_str()); + // Copy offsets skipping the first element as this string start delimiter is already + // provided by the end delimiter of the current offsets array + self.offsets.extend( + other + .offsets + .iter() + .skip(1) + .map(|x| K::from_usize(x.as_() + offset).expect("failed to fit into offset type")), + ) + } + + /// Extends this [`PackedStringArray`] by `range` elements from `other` + pub fn extend_from_range(&mut self, other: &PackedStringArray, range: Range) { + let first_offset: usize = other.offsets[range.start].as_(); + let end_offset: usize = 
other.offsets[range.end].as_(); + + let insert_offset = self.storage.len(); + + self.storage + .push_str(&other.storage[first_offset..end_offset]); + + self.offsets.extend( + other.offsets[(range.start + 1)..(range.end + 1)] + .iter() + .map(|x| { + K::from_usize(x.as_() - first_offset + insert_offset) + .expect("failed to fit into offset type") + }), + ) + } + + /// Get the value at a given index + pub fn get(&self, index: usize) -> Option<&str> { + let start_offset = self.offsets.get(index)?.as_(); + let end_offset = self.offsets.get(index + 1)?.as_(); + + Some(&self.storage[start_offset..end_offset]) + } + + /// Pads with empty strings to reach length + pub fn extend(&mut self, len: usize) { + let offset = K::from_usize(self.storage.len()).expect("failed to fit into offset type"); + self.offsets.resize(self.offsets.len() + len, offset); + } + + /// Truncates the array to the given length + pub fn truncate(&mut self, len: usize) { + self.offsets.truncate(len + 1); + let last_idx = self.offsets.last().expect("offsets empty"); + self.storage.truncate(last_idx.as_()); + } + + /// Removes all elements from this array + pub fn clear(&mut self) { + self.offsets.truncate(1); + self.storage.clear(); + } + + pub fn iter(&self) -> PackedStringIterator<'_, K> { + PackedStringIterator { + array: self, + index: 0, + } + } + + /// The number of strings in this array + pub fn len(&self) -> usize { + self.offsets.len() - 1 + } + + pub fn is_empty(&self) -> bool { + self.offsets.len() == 1 + } + + /// Return the amount of memory in bytes taken up by this array + pub fn size(&self) -> usize { + self.storage.capacity() + self.offsets.capacity() * std::mem::size_of::() + } + + pub fn inner(&self) -> (&[K], &str) { + (&self.offsets, &self.storage) + } + + pub fn into_inner(self) -> (Vec, String) { + (self.offsets, self.storage) + } + + /// Split this [`PackedStringArray`] at `n`, such that `self`` contains the + /// elements `[0, n)` and the returned [`PackedStringArray`] contains + 
/// elements `[n, len)`. + pub fn split_off(&mut self, n: usize) -> Self { + if n > self.len() { + return Default::default(); + } + + let offsets = self.offsets.split_off(n + 1); + + // Figure out where to split the string storage. + let split_point = self.offsets.last().map(|v| v.as_()).unwrap(); + + // Split the storage at the split point, such that the first N values + // appear in self. + let storage = self.storage.split_off(split_point); + + // The new "offsets" now needs remapping such that the first offset + // starts at 0, so that indexing into the new storage string will hit + // the right start point. + let offsets = std::iter::once(K::zero()) + .chain( + offsets + .into_iter() + .map(|v| K::from_usize(v.as_() - split_point).unwrap()), + ) + .collect::>(); + + Self { offsets, storage } + } +} + +impl PackedStringArray { + /// Convert to an arrow with an optional null bitmask + pub fn to_arrow(&self, nulls: Option) -> StringArray { + let len = self.offsets.len() - 1; + let offsets = Buffer::from_slice_ref(&self.offsets); + let values = Buffer::from(self.storage.as_bytes()); + + let data = ArrayDataBuilder::new(arrow::datatypes::DataType::Utf8) + .len(len) + .add_buffer(offsets) + .add_buffer(values) + .nulls(nulls) + .build() + // TODO consider skipping the validation checks by using + // `new_unchecked` + .expect("Valid array data"); + StringArray::from(data) + } +} + +#[derive(Debug)] +pub struct PackedStringIterator<'a, K> { + array: &'a PackedStringArray, + index: usize, +} + +impl<'a, K: AsPrimitive + FromPrimitive + Zero> Iterator for PackedStringIterator<'a, K> { + type Item = &'a str; + + fn next(&mut self) -> Option { + let item = self.array.get(self.index)?; + self.index += 1; + Some(item) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.array.len() - self.index; + (len, Some(len)) + } +} + +#[cfg(test)] +mod tests { + use crate::string::PackedStringArray; + + use proptest::prelude::*; + + #[test] + fn test_storage() { + let mut 
array = PackedStringArray::::new(); + + array.append("hello"); + array.append("world"); + array.append("cupcake"); + + assert_eq!(array.get(0).unwrap(), "hello"); + assert_eq!(array.get(1).unwrap(), "world"); + assert_eq!(array.get(2).unwrap(), "cupcake"); + assert!(array.get(-1_i32 as usize).is_none()); + + assert!(array.get(3).is_none()); + + array.extend(2); + assert_eq!(array.get(3).unwrap(), ""); + assert_eq!(array.get(4).unwrap(), ""); + assert!(array.get(5).is_none()); + } + + #[test] + fn test_empty() { + let array = PackedStringArray::::new_empty(20); + assert_eq!(array.get(12).unwrap(), ""); + assert_eq!(array.get(9).unwrap(), ""); + assert_eq!(array.get(3).unwrap(), ""); + } + + #[test] + fn test_truncate() { + let mut array = PackedStringArray::::new(); + + array.append("hello"); + array.append("world"); + array.append("cupcake"); + + array.truncate(1); + assert_eq!(array.len(), 1); + assert_eq!(array.get(0).unwrap(), "hello"); + + array.append("world"); + assert_eq!(array.len(), 2); + assert_eq!(array.get(0).unwrap(), "hello"); + assert_eq!(array.get(1).unwrap(), "world"); + } + + #[test] + fn test_extend_from() { + let mut a = PackedStringArray::::new(); + + a.append("hello"); + a.append("world"); + a.append("cupcake"); + a.append(""); + + let mut b = PackedStringArray::::new(); + + b.append("foo"); + b.append("bar"); + + a.extend_from(&b); + + let a_content: Vec<_> = a.iter().collect(); + assert_eq!( + a_content, + vec!["hello", "world", "cupcake", "", "foo", "bar"] + ); + } + + #[test] + fn test_extend_from_range() { + let mut a = PackedStringArray::::new(); + + a.append("hello"); + a.append("world"); + a.append("cupcake"); + a.append(""); + + let mut b = PackedStringArray::::new(); + + b.append("foo"); + b.append("bar"); + b.append(""); + b.append("fiz"); + + a.extend_from_range(&b, 1..3); + + assert_eq!(a.len(), 6); + + let a_content: Vec<_> = a.iter().collect(); + assert_eq!(a_content, vec!["hello", "world", "cupcake", "", "bar", ""]); + + // 
Should be a no-op + a.extend_from_range(&b, 0..0); + + let a_content: Vec<_> = a.iter().collect(); + assert_eq!(a_content, vec!["hello", "world", "cupcake", "", "bar", ""]); + + a.extend_from_range(&b, 0..1); + + let a_content: Vec<_> = a.iter().collect(); + assert_eq!( + a_content, + vec!["hello", "world", "cupcake", "", "bar", "", "foo"] + ); + + a.extend_from_range(&b, 1..4); + + let a_content: Vec<_> = a.iter().collect(); + assert_eq!( + a_content, + vec!["hello", "world", "cupcake", "", "bar", "", "foo", "bar", "", "fiz"] + ); + } + + proptest! { + #[test] + fn prop_split_off( + a in prop::collection::vec(any::(), 0..20), + b in prop::collection::vec(any::(), 0..20), + ) { + let mut p = PackedStringArray::::new(); + + // Add all the elements in "a" and "b" to the string array. + for v in a.iter().chain(b.iter()) { + p.append(v); + } + + // Split the packed string array at the boundary of "a". + let p2 = p.split_off(a.len()); + + assert_eq!(p.iter().collect::>(), a, "parent"); + assert_eq!(p2.iter().collect::>(), b, "child"); + } + } + + #[test] + fn test_split_off_oob() { + let mut p = PackedStringArray::::new(); + + p.append("bananas"); + + let got = p.split_off(42); + assert_eq!(p.len(), 1); + assert_eq!(got.len(), 0); + } +} diff --git a/arrow_util/src/test_util.rs b/arrow_util/src/test_util.rs new file mode 100644 index 0000000..8126e25 --- /dev/null +++ b/arrow_util/src/test_util.rs @@ -0,0 +1,419 @@ +//! A collection of testing functions for arrow based code +use std::sync::Arc; + +use crate::display::pretty_format_batches; +use arrow::{ + array::{new_null_array, ArrayRef, StringArray}, + compute::kernels::sort::{lexsort, SortColumn, SortOptions}, + datatypes::Schema, + error::ArrowError, + record_batch::RecordBatch, +}; +use once_cell::sync::Lazy; +use regex::{Captures, Regex}; +use std::{borrow::Cow, collections::HashMap}; +use uuid::Uuid; + +/// Compares the formatted output with the pretty formatted results of +/// record batches. 
This is a macro so errors appear on the correct line +/// +/// Designed so that failure output can be directly copy/pasted +/// into the test code as expected results. +/// +/// Expects to be called about like this: +/// assert_batches_eq(expected_lines: &[&str], chunks: &[RecordBatch]) +#[macro_export] +macro_rules! assert_batches_eq { + ($EXPECTED_LINES: expr, $CHUNKS: expr) => { + let expected_lines: Vec = + $EXPECTED_LINES.into_iter().map(|s| s.to_string()).collect(); + + let actual_lines = arrow_util::test_util::batches_to_lines($CHUNKS); + + assert_eq!( + expected_lines, actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + }; +} + +/// Compares formatted output of a record batch with an expected +/// vector of strings in a way that order does not matter. +/// This is a macro so errors appear on the correct line +/// +/// Designed so that failure output can be directly copy/pasted +/// into the test code as expected results. +/// +/// Expects to be called about like this: +/// +/// `assert_batch_sorted_eq!(expected_lines: &[&str], batches: &[RecordBatch])` +#[macro_export] +macro_rules! 
assert_batches_sorted_eq { + ($EXPECTED_LINES: expr, $CHUNKS: expr) => { + let expected_lines: Vec = $EXPECTED_LINES.iter().map(|&s| s.into()).collect(); + let expected_lines = arrow_util::test_util::sort_lines(expected_lines); + + let actual_lines = arrow_util::test_util::batches_to_sorted_lines($CHUNKS); + + assert_eq!( + expected_lines, actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + }; +} + +/// Converts the [`RecordBatch`]es into a pretty printed output suitable for +/// comparing in tests +/// +/// Example: +/// +/// ```text +/// "+-----+------+------+--------------------------------+", +/// "| foo | host | load | time |", +/// "+-----+------+------+--------------------------------+", +/// "| | a | 1.0 | 1970-01-01T00:00:00.000000011Z |", +/// "| | a | 14.0 | 1970-01-01T00:00:00.000010001Z |", +/// "| | a | 3.0 | 1970-01-01T00:00:00.000000033Z |", +/// "| | b | 5.0 | 1970-01-01T00:00:00.000000011Z |", +/// "| | z | 0.0 | 1970-01-01T00:00:00Z |", +/// "+-----+------+------+--------------------------------+", +/// ``` +pub fn batches_to_lines(batches: &[RecordBatch]) -> Vec { + crate::display::pretty_format_batches(batches) + .unwrap() + .trim() + .lines() + .map(|s| s.to_string()) + .collect() +} + +/// Converts the [`RecordBatch`]es into a pretty printed output suitable for +/// comparing in tests where sorting does not matter. 
+pub fn batches_to_sorted_lines(batches: &[RecordBatch]) -> Vec { + sort_lines(batches_to_lines(batches)) +} + +/// Sorts the lines (assumed to be the output of `batches_to_lines` for stable comparison) +pub fn sort_lines(mut lines: Vec) -> Vec { + // sort except for header + footer + let num_lines = lines.len(); + if num_lines > 3 { + lines.as_mut_slice()[2..num_lines - 1].sort_unstable() + } + lines +} + +// sort a record batch by all columns (to provide a stable output order for test +// comparison) +pub fn sort_record_batch(batch: RecordBatch) -> RecordBatch { + let sort_input: Vec = batch + .columns() + .iter() + .map(|col| SortColumn { + values: Arc::clone(col), + options: Some(SortOptions { + descending: false, + nulls_first: false, + }), + }) + .collect(); + + let sort_output = lexsort(&sort_input, None).expect("Sorting to complete"); + + RecordBatch::try_new(batch.schema(), sort_output).unwrap() +} + +/// Return a new `StringArray` where each element had a normalization +/// function `norm` applied. +pub fn normalize_string_array(arr: &StringArray, norm: N) -> ArrayRef +where + N: Fn(&str) -> String, +{ + let normalized: StringArray = arr.iter().map(|s| s.map(&norm)).collect(); + Arc::new(normalized) +} + +/// Return a new set of `RecordBatch`es where the function `norm` has +/// applied to all `StringArray` rows. 
+pub fn normalize_batches(batches: Vec, norm: N) -> Vec +where + N: Fn(&str) -> String, +{ + // The idea here is is to get a function that normalizes strings + // and apply it to each StringArray element by element + batches + .into_iter() + .map(|batch| { + let new_columns: Vec<_> = batch + .columns() + .iter() + .map(|array| { + if let Some(array) = array.as_any().downcast_ref::() { + normalize_string_array(array, &norm) + } else { + Arc::clone(array) + } + }) + .collect(); + + RecordBatch::try_new(batch.schema(), new_columns) + .expect("error occurred during normalization") + }) + .collect() +} + +/// Equalize batch schemas by creating NULL columns. +pub fn equalize_batch_schemas(batches: Vec) -> Result, ArrowError> { + let common_schema = Arc::new(Schema::try_merge( + batches.iter().map(|batch| batch.schema().as_ref().clone()), + )?); + + Ok(batches + .into_iter() + .map(|batch| { + let batch_schema = batch.schema(); + let columns = common_schema + .fields() + .iter() + .map(|field| match batch_schema.index_of(field.name()) { + Ok(idx) => Arc::clone(batch.column(idx)), + Err(_) => new_null_array(field.data_type(), batch.num_rows()), + }) + .collect(); + RecordBatch::try_new(Arc::clone(&common_schema), columns).unwrap() + }) + .collect()) +} + +/// Match the parquet UUID +/// +/// For example, given +/// `32/51/216/13452/1d325760-2b20-48de-ab48-2267b034133d.parquet` +/// +/// matches `1d325760-2b20-48de-ab48-2267b034133d` +pub static REGEX_UUID: Lazy = Lazy::new(|| { + Regex::new("[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}").expect("UUID regex") +}); + +/// Match the parquet directory names +/// For example, given +/// `51/216/1a3f45021a3f45021a3f45021a3f45021a3f45021a3f45021a3f45021a3f4502/1d325760-2b20-48de-ab48-2267b034133d.parquet` +/// +/// matches `51/216/1a3f45021a3f45021a3f45021a3f45021a3f45021a3f45021a3f45021a3f4502` +static REGEX_DIRS: Lazy = + Lazy::new(|| Regex::new(r#"[0-9]+/[0-9]+/[0-9a-f]+/"#).expect("directory regex")); + +/// 
Replace table row separators of flexible width with fixed with. This is required +/// because the original timing values may differ in "printed width", so the table +/// cells have different widths and hence the separators / borders. E.g.: +/// +/// `+--+--+` -> `----------` +/// `+--+------+` -> `----------` +/// +/// Note that we're kinda inexact with our regex here, but it gets the job done. +static REGEX_LINESEP: Lazy = Lazy::new(|| Regex::new(r#"[+-]{6,}"#).expect("linesep regex")); + +/// Similar to the row separator issue above, the table columns are right-padded +/// with spaces. Due to the different "printed width" of the timing values, we need +/// to normalize this padding as well. E.g.: +/// +/// ` |` -> ` |` +/// ` |` -> ` |` +static REGEX_COL: Lazy = Lazy::new(|| Regex::new(r"\s+\|").expect("col regex")); + +/// Matches line like `metrics=[foo=1, bar=2]` +static REGEX_METRICS: Lazy = + Lazy::new(|| Regex::new(r"metrics=\[([^\]]*)\]").expect("metrics regex")); + +/// Matches things like `1s`, `1.2ms` and `10.2μs` +static REGEX_TIMING: Lazy = + Lazy::new(|| Regex::new(r"[0-9]+(\.[0-9]+)?.s").expect("timing regex")); + +/// Matches things like `FilterExec: .*` and `ParquetExec: .*` +/// +/// Should be used in combination w/ [`REGEX_TIME_OP`]. +static REGEX_FILTER: Lazy = Lazy::new(|| { + Regex::new("(?P(FilterExec)|(ParquetExec): )(?P.*)").expect("filter regex") +}); + +/// Matches things like `time@3 < -9223372036854775808` and `time_min@2 > 1641031200399937022` +static REGEX_TIME_OP: Lazy = Lazy::new(|| { + Regex::new("(?Ptime((_min)|(_max))?@[0-9]+ [<>=]=? 
(CAST\\()?)(?P-?[0-9]+)(?P AS Timestamp\\(Nanosecond, \"[^\"]\"\\)\\))?") + .expect("time opt regex") +}); + +fn normalize_for_variable_width(s: Cow<'_, str>) -> String { + let s = REGEX_LINESEP.replace_all(&s, "----------"); + REGEX_COL.replace_all(&s, " |").to_string() +} + +pub fn strip_table_lines(s: Cow<'_, str>) -> String { + let s = REGEX_LINESEP.replace_all(&s, "----------"); + REGEX_COL.replace_all(&s, "").to_string() +} + +fn normalize_time_ops(s: &str) -> String { + REGEX_TIME_OP + .replace_all(s, |c: &Captures<'_>| { + let prefix = c.name("prefix").expect("always captures").as_str(); + let suffix = c.name("suffix").map_or("", |m| m.as_str()); + format!("{prefix}{suffix}") + }) + .to_string() +} + +/// A query to run with optional annotations +#[derive(Debug, PartialEq, Eq, Default, Clone, Copy)] +pub struct Normalizer { + /// If true, results are sorted first + pub sorted_compare: bool, + + /// If true, replace UUIDs with static placeholders. + pub normalized_uuids: bool, + + /// If true, normalize timings in queries by replacing them with + /// static placeholders, for example: + /// + /// `1s` -> `1.234ms` + pub normalized_metrics: bool, + + /// if true, normalize filter predicates for explain plans + /// `FilterExec: ` + pub normalized_filters: bool, + + /// if `true`, render tables without borders. 
+ pub no_table_borders: bool, +} + +impl Normalizer { + pub fn new() -> Self { + Default::default() + } + + /// Take the output of running the query and apply the specified normalizations to them + pub fn normalize_results(&self, mut results: Vec) -> Vec { + // compare against sorted results, if requested + if self.sorted_compare && !results.is_empty() { + let schema = results[0].schema(); + let batch = + arrow::compute::concat_batches(&schema, &results).expect("concatenating batches"); + results = vec![sort_record_batch(batch)]; + } + + let mut current_results = pretty_format_batches(&results) + .unwrap() + .trim() + .lines() + .map(|s| s.to_string()) + .collect::>(); + + // normalize UUIDs, if requested + if self.normalized_uuids { + let mut seen: HashMap = HashMap::new(); + current_results = current_results + .into_iter() + .map(|s| { + // Rewrite Parquet directory names like + // `51/216/1a3f45021a3f45021a3f45021a3f45021a3f45021a3f45021a3f45021a3f4502/1d325760-2b20-48de-ab48-2267b034133d.parquet` + // + // to: + // 1/1/1/00000000-0000-0000-0000-000000000000.parquet + + let s = REGEX_UUID.replace_all(&s, |s: &Captures<'_>| { + let next = seen.len() as u128; + Uuid::from_u128( + *seen + .entry(s.get(0).unwrap().as_str().to_owned()) + .or_insert(next), + ) + .to_string() + }); + + let s = normalize_for_variable_width(s); + REGEX_DIRS.replace_all(&s, "1/1/1/").to_string() + }) + .collect(); + } + + // normalize metrics, if requested + if self.normalized_metrics { + current_results = current_results + .into_iter() + .map(|s| { + // Replace timings with fixed value, e.g.: + // + // `1s` -> `1.234ms` + // `1.2ms` -> `1.234ms` + // `10.2μs` -> `1.234ms` + let s = REGEX_TIMING.replace_all(&s, "1.234ms"); + + let s = normalize_for_variable_width(s); + + // Metrics are currently ordered by value (not by key), so different timings may + // reorder them. We "parse" the list and normalize the sorting. 
E.g.: + // + // `metrics=[]` => `metrics=[]` + // `metrics=[foo=1, bar=2]` => `metrics=[bar=2, foo=1]` + // `metrics=[foo=2, bar=1]` => `metrics=[bar=1, foo=2]` + REGEX_METRICS + .replace_all(&s, |c: &Captures<'_>| { + let mut metrics: Vec<_> = c[1].split(", ").collect(); + metrics.sort(); + format!("metrics=[{}]", metrics.join(", ")) + }) + .to_string() + }) + .collect(); + } + + // normalize Filters, if requested + // + // Converts: + // FilterExec: time@2 < -9223372036854775808 OR time@2 > 1640995204240217000 + // ParquetExec: limit=None, partitions={...}, predicate=time@2 > 1640995204240217000, pruning_predicate=time@2 > 1640995204240217000, output_ordering=[...], projection=[...] + // + // to + // FilterExec: time@2 < OR time@2 > + // ParquetExec: limit=None, partitions={...}, predicate=time@2 > , pruning_predicate=time@2 > , output_ordering=[...], projection=[...] + if self.normalized_filters { + current_results = current_results + .into_iter() + .map(|s| { + REGEX_FILTER + .replace_all(&s, |c: &Captures<'_>| { + let prefix = c.name("prefix").expect("always captrues").as_str(); + + let expr = c.name("expr").expect("always captures").as_str(); + let expr = normalize_time_ops(expr); + + format!("{prefix}{expr}") + }) + .to_string() + }) + .collect(); + } + + current_results + } + + /// Adds information on what normalizations were applied to the input + pub fn add_description(&self, output: &mut Vec) { + if self.sorted_compare { + output.push("-- Results After Sorting".into()) + } + if self.normalized_uuids { + output.push("-- Results After Normalizing UUIDs".into()) + } + if self.normalized_metrics { + output.push("-- Results After Normalizing Metrics".into()) + } + if self.normalized_filters { + output.push("-- Results After Normalizing Filters".into()) + } + if self.no_table_borders { + output.push("-- Results After No Table Borders".into()) + } + } +} diff --git a/arrow_util/src/util.rs b/arrow_util/src/util.rs new file mode 100644 index 0000000..3677dd0 --- 
/dev/null +++ b/arrow_util/src/util.rs @@ -0,0 +1,57 @@ +//! Utility functions for working with arrow + +use std::iter::FromIterator; +use std::sync::Arc; + +use arrow::{ + array::{new_null_array, ArrayRef, StringArray}, + datatypes::SchemaRef, + error::ArrowError, + record_batch::RecordBatch, +}; + +/// Returns a single column record batch of type Utf8 from the +/// contents of something that can be turned into an iterator over +/// `Option<&str>` +pub fn str_iter_to_batch(field_name: &str, iter: I) -> Result +where + I: IntoIterator>, + Ptr: AsRef, +{ + let array = StringArray::from_iter(iter); + + RecordBatch::try_from_iter(vec![(field_name, Arc::new(array) as ArrayRef)]) +} + +/// Ensures the record batch has the specified schema +pub fn ensure_schema( + output_schema: &SchemaRef, + batch: &RecordBatch, +) -> Result { + let batch_schema = batch.schema(); + + // Go over all columns of output_schema + let batch_output_columns = output_schema + .fields() + .iter() + .map(|output_field| { + // See if the output_field available in the batch + let batch_field_index = batch_schema + .fields() + .iter() + .enumerate() + .find(|(_, batch_field)| output_field.name() == batch_field.name()) + .map(|(idx, _)| idx); + + if let Some(batch_field_index) = batch_field_index { + // The column available, use it + Arc::clone(batch.column(batch_field_index)) + } else { + // the column not available, add it with all null values + new_null_array(output_field.data_type(), batch.num_rows()) + } + }) + .collect::>(); + + RecordBatch::try_new(Arc::clone(output_schema), batch_output_columns) +} diff --git a/authz/Cargo.toml b/authz/Cargo.toml new file mode 100644 index 0000000..9fc5ed9 --- /dev/null +++ b/authz/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "authz" +description = "Interface to authorization checking services" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[lints] +workspace = true + +[dependencies] +backoff = { 
path = "../backoff" } +http = {version = "0.2.11", optional = true } +iox_time = { version = "0.1.0", path = "../iox_time" } +generated_types = { path = "../generated_types" } +metric = { version = "0.1.0", path = "../metric" } +observability_deps = { path = "../observability_deps" } +workspace-hack = { version = "0.1", path = "../workspace-hack" } + +# crates.io dependencies in alphabetical order. +async-trait = "0.1" +base64 = "0.21.7" +snafu = "0.8" +tonic = { workspace = true } + +[dev-dependencies] +assert_matches = "1.5.0" +parking_lot = "0.12.1" +paste = "1.0.14" +test_helpers_end_to_end = { path = "../test_helpers_end_to_end" } +tokio = "1.35.1" + +[features] +http = ["dep:http"] diff --git a/authz/src/authorizer.rs b/authz/src/authorizer.rs new file mode 100644 index 0000000..488ceb5 --- /dev/null +++ b/authz/src/authorizer.rs @@ -0,0 +1,88 @@ +use std::ops::ControlFlow; + +use async_trait::async_trait; +use backoff::{Backoff, BackoffConfig}; + +use super::{Error, Permission}; + +/// An authorizer is used to validate a request +/// (+ associated permissions needed to fulfill the request) +/// with an authorization token that has been extracted from the request. +#[async_trait] +pub trait Authorizer: std::fmt::Debug + Send + Sync { + /// Determine the permissions associated with a request token. + /// + /// The returned list of permissions is the intersection of the permissions + /// requested and the permissions associated with the token. 
+ /// + /// Implementations of this trait should return the specified errors under + /// the following conditions: + /// + /// * [`Error::InvalidToken`]: the token is invalid / in an incorrect + /// format / otherwise corrupt and a permission check cannot be + /// performed + /// + /// * [`Error::NoToken`]: the token was not provided + /// + /// * [`Error::Forbidden`]: the token was well formed, but lacks + /// authorisation to perform the requested action + /// + /// * [`Error::Verification`]: the token permissions were not possible + /// to validate - an internal error has occurred + async fn permissions( + &self, + token: Option>, + perms: &[Permission], + ) -> Result, Error>; + + /// Make a test request that determines if end-to-end communication + /// with the service is working. + /// + /// Test is performed during deployment, with ordering of availability not being guaranteed. + async fn probe(&self) -> Result<(), Error> { + Backoff::new(&BackoffConfig::default()) + .retry_with_backoff("probe iox-authz service", move || { + async { + match self.permissions(Some(b"".to_vec()), &[]).await { + // got response from authorizer server + Ok(_) + | Err(Error::Forbidden) + | Err(Error::InvalidToken) + | Err(Error::NoToken) => ControlFlow::Break(Ok(())), + // communication error == Error::Verification + Err(e) => ControlFlow::<_, Error>::Continue(e), + } + } + }) + .await + .expect("retry forever") + } +} + +/// Wrapped `Option` +/// Provides response to inner `IoxAuthorizer::permissions()` +#[async_trait] +impl Authorizer for Option { + async fn permissions( + &self, + token: Option>, + perms: &[Permission], + ) -> Result, Error> { + match self { + Some(authz) => authz.permissions(token, perms).await, + // no authz rpc service => return same perms requested. Used for testing. 
+ None => Ok(perms.to_vec()), + } + } +} + +#[async_trait] +impl + std::fmt::Debug + Send + Sync> Authorizer for T { + async fn permissions( + &self, + token: Option>, + perms: &[Permission], + ) -> Result, Error> { + self.as_ref().permissions(token, perms).await + } +} diff --git a/authz/src/http.rs b/authz/src/http.rs new file mode 100644 index 0000000..e45b37e --- /dev/null +++ b/authz/src/http.rs @@ -0,0 +1,29 @@ +//! HTTP authorisation helpers. + +use http::HeaderValue; + +/// We strip off the "authorization" header from the request, to prevent it from being accidentally logged +/// and we put it in an extension of the request. Extensions are typed and this is the typed wrapper that +/// holds an (optional) authorization header value. +pub struct AuthorizationHeaderExtension(Option); + +impl AuthorizationHeaderExtension { + /// Construct new extension wrapper for a possible header value + pub fn new(header: Option) -> Self { + Self(header) + } +} + +impl std::fmt::Debug for AuthorizationHeaderExtension { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("AuthorizationHeaderExtension(...)") + } +} + +impl std::ops::Deref for AuthorizationHeaderExtension { + type Target = Option; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/authz/src/instrumentation.rs b/authz/src/instrumentation.rs new file mode 100644 index 0000000..b64fc2c --- /dev/null +++ b/authz/src/instrumentation.rs @@ -0,0 +1,248 @@ +use async_trait::async_trait; +use iox_time::{SystemProvider, TimeProvider}; +use metric::{DurationHistogram, Metric, Registry}; + +use super::{Authorizer, Error, Permission}; + +const AUTHZ_DURATION_METRIC: &str = "authz_permission_check_duration"; + +/// An instrumentation decorator over a [`Authorizer`] implementation. +/// +/// This wrapper captures the latency distribution of the decorated +/// [`Authorizer::permissions()`] call, faceted by success/error result. 
+#[derive(Debug)] +pub struct AuthorizerInstrumentation { + inner: T, + time_provider: P, + + /// Permissions-check duration distribution for successesful rpc, but not authorized. + ioxauth_rpc_duration_success_unauth: DurationHistogram, + + /// Permissions-check duration distribution for successesful rpc + authorized. + ioxauth_rpc_duration_success_auth: DurationHistogram, + + /// Permissions-check duration distribution for errors. + ioxauth_rpc_duration_error: DurationHistogram, +} + +impl AuthorizerInstrumentation { + /// Record permissions-check duration metrics, broken down by result. + pub fn new(registry: &Registry, inner: T) -> Self { + let metric: Metric = + registry.register_metric(AUTHZ_DURATION_METRIC, "duration of authz permissions check"); + + let ioxauth_rpc_duration_success_unauth = + metric.recorder(&[("result", "success"), ("auth_state", "unauthorised")]); + let ioxauth_rpc_duration_success_auth = + metric.recorder(&[("result", "success"), ("auth_state", "authorised")]); + let ioxauth_rpc_duration_error = + metric.recorder(&[("result", "error"), ("auth_state", "unauthorised")]); + + Self { + inner, + time_provider: Default::default(), + ioxauth_rpc_duration_success_unauth, + ioxauth_rpc_duration_success_auth, + ioxauth_rpc_duration_error, + } + } +} + +#[async_trait] +impl Authorizer for AuthorizerInstrumentation +where + T: Authorizer, +{ + async fn permissions( + &self, + token: Option>, + perms: &[Permission], + ) -> Result, Error> { + let t = self.time_provider.now(); + let res = self.inner.permissions(token, perms).await; + + if let Some(delta) = self.time_provider.now().checked_duration_since(t) { + match &res { + Ok(_) => self.ioxauth_rpc_duration_success_auth.record(delta), + Err(Error::Forbidden) | Err(Error::InvalidToken) => { + self.ioxauth_rpc_duration_success_unauth.record(delta) + } + Err(Error::Verification { .. 
}) => self.ioxauth_rpc_duration_error.record(delta), + Err(Error::NoToken) => {} // rpc was not made + }; + } + + res + } +} + +#[cfg(test)] +mod test { + use std::collections::VecDeque; + + use metric::{assert_histogram, Attributes, Registry}; + use parking_lot::Mutex; + + use super::*; + use crate::{Action, Resource}; + + #[derive(Debug, Default)] + struct MockAuthorizerState { + ret: VecDeque, Error>>, + } + + #[derive(Debug, Default)] + struct MockAuthorizer { + state: Mutex, + } + + impl MockAuthorizer { + pub(crate) fn with_permissions_return( + self, + ret: impl Into, Error>>>, + ) -> Self { + self.state.lock().ret = ret.into(); + self + } + } + + #[async_trait] + impl Authorizer for MockAuthorizer { + async fn permissions( + &self, + _token: Option>, + _perms: &[Permission], + ) -> Result, Error> { + self.state + .lock() + .ret + .pop_front() + .expect("no mock sink value to return") + } + } + + macro_rules! assert_metric_counts { + ( + $metrics:ident, + expected_success = $expected_success:expr, + expected_rpc_success_unauth = $expected_rpc_success_unauth:expr, + expected_rpc_failures = $expected_rpc_failures:expr, + ) => { + let histogram = &$metrics + .get_instrument::>(AUTHZ_DURATION_METRIC) + .expect("failed to read metric"); + + let success_labels = + Attributes::from(&[("result", "success"), ("auth_state", "authorised")]); + let histogram_success = &histogram + .get_observer(&success_labels) + .expect("failed to find metric with provided attributes") + .fetch(); + + assert_histogram!( + $metrics, + DurationHistogram, + AUTHZ_DURATION_METRIC, + labels = success_labels, + samples = $expected_success, + sum = histogram_success.total, + ); + + let success_unauth_labels = + Attributes::from(&[("result", "success"), ("auth_state", "unauthorised")]); + let histogram_success_unauth = &histogram + .get_observer(&success_unauth_labels) + .expect("failed to find metric with provided attributes") + .fetch(); + + assert_histogram!( + $metrics, + 
DurationHistogram, + AUTHZ_DURATION_METRIC, + labels = success_unauth_labels, + samples = $expected_rpc_success_unauth, + sum = histogram_success_unauth.total, + ); + + let rpc_error_labels = + Attributes::from(&[("result", "error"), ("auth_state", "unauthorised")]); + let histogram_rpc_error = &histogram + .get_observer(&rpc_error_labels) + .expect("failed to find metric with provided attributes") + .fetch(); + + assert_histogram!( + $metrics, + DurationHistogram, + AUTHZ_DURATION_METRIC, + labels = rpc_error_labels, + samples = $expected_rpc_failures, + sum = histogram_rpc_error.total, + ); + }; + } + + macro_rules! test_authorizer_metric { + ( + $name:ident, + rpc_response = $rpc_response:expr, + will_pass_auth = $will_pass_auth:expr, + expected_success_cnt = $expected_success_cnt:expr, + expected_success_unauth_cnt = $expected_success_unauth_cnt:expr, + expected_error_cnt = $expected_error_cnt:expr, + ) => { + paste::paste! { + #[tokio::test] + async fn []() { + let metrics = Registry::default(); + let decorated_authz = AuthorizerInstrumentation::new( + &metrics, + MockAuthorizer::default().with_permissions_return([$rpc_response]) + ); + + let token = "any".as_bytes().to_vec(); + let got = decorated_authz + .permissions(Some(token), &[]) + .await; + assert_eq!(got.is_ok(), $will_pass_auth); + assert_metric_counts!( + metrics, + expected_success = $expected_success_cnt, + expected_rpc_success_unauth = $expected_success_unauth_cnt, + expected_rpc_failures = $expected_error_cnt, + ); + } + } + }; + } + + test_authorizer_metric!( + ok, + rpc_response = Ok(vec![Permission::ResourceAction( + Resource::Database("foo".to_string()), + Action::Write, + )]), + will_pass_auth = true, + expected_success_cnt = 1, + expected_success_unauth_cnt = 0, + expected_error_cnt = 0, + ); + + test_authorizer_metric!( + will_record_failure_if_rpc_fails, + rpc_response = Err(Error::verification("test", "test error")), + will_pass_auth = false, + expected_success_cnt = 0, + 
expected_success_unauth_cnt = 0, + expected_error_cnt = 1, + ); + + test_authorizer_metric!( + will_record_success_if_rpc_pass_but_auth_fails, + rpc_response = Err(Error::InvalidToken), + will_pass_auth = false, + expected_success_cnt = 0, + expected_success_unauth_cnt = 1, + expected_error_cnt = 0, + ); +} diff --git a/authz/src/iox_authorizer.rs b/authz/src/iox_authorizer.rs new file mode 100644 index 0000000..7228d3d --- /dev/null +++ b/authz/src/iox_authorizer.rs @@ -0,0 +1,309 @@ +use async_trait::async_trait; +use generated_types::influxdata::iox::authz::v1::{self as proto, AuthorizeResponse}; +use observability_deps::tracing::warn; +use snafu::Snafu; +use tonic::Response; + +use super::{Authorizer, Permission}; + +/// Authorizer implementation using influxdata.iox.authz.v1 protocol. +#[derive(Clone, Debug)] +pub struct IoxAuthorizer { + client: + proto::iox_authorizer_service_client::IoxAuthorizerServiceClient, +} + +impl IoxAuthorizer { + /// Attempt to create a new client by connecting to a given endpoint. 
+ pub fn connect_lazy(dst: D) -> Result> + where + D: TryInto + Send, + D::Error: Into, + { + let ep = tonic::transport::Endpoint::new(dst)?; + let client = proto::iox_authorizer_service_client::IoxAuthorizerServiceClient::new( + ep.connect_lazy(), + ); + Ok(Self { client }) + } + + async fn request( + &self, + token: Vec, + requested_perms: &[Permission], + ) -> Result, tonic::Status> { + let req = proto::AuthorizeRequest { + token, + permissions: requested_perms + .iter() + .filter_map(|p| p.clone().try_into().ok()) + .collect(), + }; + let mut client = self.client.clone(); + client.authorize(req).await + } +} + +#[async_trait] +impl Authorizer for IoxAuthorizer { + async fn permissions( + &self, + token: Option>, + requested_perms: &[Permission], + ) -> Result, Error> { + let authz_rpc_result = self + .request(token.ok_or(Error::NoToken)?, requested_perms) + .await + .map_err(|status| Error::Verification { + msg: status.message().to_string(), + source: Box::new(status), + })? + .into_inner(); + + if !authz_rpc_result.valid { + return Err(Error::InvalidToken); + } + + let intersected_perms: Vec = authz_rpc_result + .permissions + .into_iter() + .filter_map(|p| match p.try_into() { + Ok(p) => Some(p), + Err(e) => { + warn!(error=%e, "authz service returned incompatible permission"); + None + } + }) + .collect(); + + if intersected_perms.is_empty() { + return Err(Error::Forbidden); + } + Ok(intersected_perms) + } +} + +/// Authorization related error. +#[derive(Debug, Snafu)] +pub enum Error { + /// Communication error when verifying a token. + #[snafu(display("token verification not possible: {msg}"))] + Verification { + /// Message describing the error. + msg: String, + /// Source of the error. + source: Box, + }, + + /// Token is invalid. + #[snafu(display("invalid token"))] + InvalidToken, + + /// The token's permissions do not allow the operation. + #[snafu(display("forbidden"))] + Forbidden, + + /// No token has been supplied, but is required. 
+ #[snafu(display("no token"))] + NoToken, +} + +impl Error { + /// Create new Error::Verification. + pub fn verification( + msg: impl Into, + source: impl Into>, + ) -> Self { + Self::Verification { + msg: msg.into(), + source: source.into(), + } + } +} + +impl From for Error { + fn from(value: tonic::Status) -> Self { + Self::verification(value.message(), value.clone()) + } +} + +#[cfg(test)] +mod test { + use std::{ + net::SocketAddr, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::Duration, + }; + + use assert_matches::assert_matches; + use test_helpers_end_to_end::Authorizer as AuthorizerServer; + use tokio::{ + net::TcpListener, + task::{spawn, JoinHandle}, + }; + use tonic::transport::{server::TcpIncoming, Server}; + + use super::*; + use crate::{Action, Authorizer, Permission, Resource}; + + const NAMESPACE: &str = "bananas"; + + macro_rules! test_iox_authorizer { + ( + $name:ident, + token_permissions = $token_permissions:expr, + permissions_required = $permissions_required:expr, + want = $want:pat + ) => { + paste::paste! 
{ + #[tokio::test] + async fn []() { + let mut authz_server = AuthorizerServer::create().await; + let authz = IoxAuthorizer::connect_lazy(authz_server.addr()) + .expect("Failed to create IoxAuthorizer client."); + + let token = authz_server.create_token_for(NAMESPACE, $token_permissions); + + let got = authz.permissions( + Some(token.as_bytes().to_vec()), + $permissions_required + ).await; + + assert_matches!(got, $want); + } + } + }; + } + + test_iox_authorizer!( + ok, + token_permissions = &["ACTION_WRITE"], + permissions_required = &[Permission::ResourceAction( + Resource::Database(NAMESPACE.to_string()), + Action::Write, + )], + want = Ok(_) + ); + + test_iox_authorizer!( + insufficient_perms, + token_permissions = &["ACTION_READ"], + permissions_required = &[Permission::ResourceAction( + Resource::Database(NAMESPACE.to_string()), + Action::Write, + )], + want = Err(Error::Forbidden) + ); + + test_iox_authorizer!( + any_of_required_perms, + token_permissions = &["ACTION_WRITE"], + permissions_required = &[ + Permission::ResourceAction(Resource::Database(NAMESPACE.to_string()), Action::Write,), + Permission::ResourceAction(Resource::Database(NAMESPACE.to_string()), Action::Create,) + ], + want = Ok(_) + ); + + #[tokio::test] + async fn test_invalid_token() { + let authz_server = AuthorizerServer::create().await; + let authz = IoxAuthorizer::connect_lazy(authz_server.addr()) + .expect("Failed to create IoxAuthorizer client."); + + let invalid_token = b"UGLY"; + + let got = authz + .permissions( + Some(invalid_token.to_vec()), + &[Permission::ResourceAction( + Resource::Database(NAMESPACE.to_string()), + Action::Read, + )], + ) + .await; + + assert_matches!(got, Err(Error::InvalidToken)); + } + + #[tokio::test] + async fn test_delayed_probe_response() { + #[derive(Default, Debug)] + struct DelayedAuthorizer(Arc); + + impl DelayedAuthorizer { + fn start_countdown(&self) { + let started = Arc::clone(&self.0); + spawn(async move { + 
tokio::time::sleep(Duration::from_secs(2)).await; + started.store(true, Ordering::Relaxed); + }); + } + } + + #[async_trait] + impl proto::iox_authorizer_service_server::IoxAuthorizerService for DelayedAuthorizer { + async fn authorize( + &self, + _request: tonic::Request, + ) -> Result, tonic::Status> { + let startup_done = self.0.load(Ordering::Relaxed); + if !startup_done { + return Err(tonic::Status::deadline_exceeded("startup not done")); + } + + Ok(tonic::Response::new(AuthorizeResponse { + valid: true, + subject: None, + permissions: vec![], + })) + } + } + + #[derive(Debug)] + struct DelayedServer { + addr: SocketAddr, + handle: JoinHandle>, + } + + impl DelayedServer { + async fn create() -> Self { + let listener = TcpListener::bind("localhost:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let incoming = TcpIncoming::from_listener(listener, false, None).unwrap(); + + // start countdown mocking startup delay of sidecar + let authz = DelayedAuthorizer::default(); + authz.start_countdown(); + + let router = Server::builder().add_service( + proto::iox_authorizer_service_server::IoxAuthorizerServiceServer::new(authz), + ); + let handle = spawn(router.serve_with_incoming(incoming)); + Self { addr, handle } + } + + fn addr(&self) -> String { + format!("http://{}", self.addr) + } + + fn close(self) { + self.handle.abort(); + } + } + + let authz_server = DelayedServer::create().await; + let authz_client = IoxAuthorizer::connect_lazy(authz_server.addr()) + .expect("Failed to create IoxAuthorizer client."); + + assert_matches!( + authz_client.probe().await, + Ok(()), + "authz probe should work even with delay" + ); + authz_server.close(); + } +} diff --git a/authz/src/lib.rs b/authz/src/lib.rs new file mode 100644 index 0000000..7b2fd54 --- /dev/null +++ b/authz/src/lib.rs @@ -0,0 +1,100 @@ +//! IOx authorization client. +//! +//! Authorization client interface to be used by IOx components to +//! 
restrict access to authorized requests where required. + +#![deny(rustdoc::broken_intra_doc_links, rust_2018_idioms)] +#![warn( + missing_copy_implementations, + missing_docs, + clippy::explicit_iter_loop, + // See https://github.com/influxdata/influxdb_iox/pull/1671 + clippy::future_not_send, + clippy::use_self, + clippy::clone_on_ref_ptr, + clippy::todo, + clippy::dbg_macro, + unused_crate_dependencies +)] +#![allow(rustdoc::private_intra_doc_links)] + +// Workaround for "unused crate" lint false positives. +use workspace_hack as _; + +use base64::{prelude::BASE64_STANDARD, Engine}; +use generated_types::influxdata::iox::authz::v1::{self as proto}; +use observability_deps::tracing::warn; + +mod authorizer; +pub use authorizer::Authorizer; +mod iox_authorizer; +pub use iox_authorizer::{Error, IoxAuthorizer}; +mod instrumentation; +pub use instrumentation::AuthorizerInstrumentation; +mod permission; +pub use permission::{Action, Permission, Resource}; + +#[cfg(feature = "http")] +pub mod http; + +/// Extract a token from an HTTP header or gRPC metadata value. +pub fn extract_token + ?Sized>(value: Option<&T>) -> Option> { + let mut parts = value?.as_ref().splitn(2, |&v| v == b' '); + let token = match parts.next()? { + b"Token" | b"Bearer" => parts.next()?.to_vec(), + b"Basic" => parts + .next() + .and_then(|v| BASE64_STANDARD.decode(v).ok())? + .splitn(2, |&v| v == b':') + .nth(1)? 
+ .to_vec(), + _ => return None, + }; + if token.is_empty() { + None + } else { + Some(token) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn verify_error_from_tonic_status() { + let s = tonic::Status::resource_exhausted("test error"); + let e = Error::from(s); + assert_eq!( + "token verification not possible: test error", + format!("{e}") + ) + } + + #[test] + fn test_extract_token() { + assert_eq!(None, extract_token::<&str>(None)); + assert_eq!(None, extract_token(Some(""))); + assert_eq!(None, extract_token(Some("Basic"))); + assert_eq!(None, extract_token(Some("Basic Og=="))); // ":" + assert_eq!(None, extract_token(Some("Basic dXNlcm5hbWU6"))); // "username:" + assert_eq!(None, extract_token(Some("Basic Og=="))); // ":" + assert_eq!( + Some(b"password".to_vec()), + extract_token(Some("Basic OnBhc3N3b3Jk")) + ); // ":password" + assert_eq!( + Some(b"password2".to_vec()), + extract_token(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQy")) + ); // "username:password2" + assert_eq!(None, extract_token(Some("Bearer"))); + assert_eq!(None, extract_token(Some("Bearer "))); + assert_eq!(Some(b"token".to_vec()), extract_token(Some("Bearer token"))); + assert_eq!(None, extract_token(Some("Token"))); + assert_eq!(None, extract_token(Some("Token "))); + assert_eq!( + Some(b"token2".to_vec()), + extract_token(Some("Token token2")) + ); + } +} diff --git a/authz/src/permission.rs b/authz/src/permission.rs new file mode 100644 index 0000000..9ffced0 --- /dev/null +++ b/authz/src/permission.rs @@ -0,0 +1,310 @@ +use super::proto; +use snafu::Snafu; + +/// Action is the type of operation being attempted on a resource. +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum Action { + /// The create action is used when a new instance of the resource will + /// be created. + Create, + /// The delete action is used when a resource will be deleted. + Delete, + /// The read action is used when the data contained by a resource will + /// be read. 
+ Read, + /// The read-schema action is used when only metadata about a resource + /// will be read. + ReadSchema, + /// The write action is used when data is being written to the resource. + Write, +} + +impl TryFrom for Action { + type Error = IncompatiblePermissionError; + + fn try_from(value: proto::resource_action_permission::Action) -> Result { + match value { + proto::resource_action_permission::Action::ReadSchema => Ok(Self::ReadSchema), + proto::resource_action_permission::Action::Read => Ok(Self::Read), + proto::resource_action_permission::Action::Write => Ok(Self::Write), + proto::resource_action_permission::Action::Create => Ok(Self::Create), + proto::resource_action_permission::Action::Delete => Ok(Self::Delete), + _ => Err(IncompatiblePermissionError {}), + } + } +} + +impl From for proto::resource_action_permission::Action { + fn from(value: Action) -> Self { + match value { + Action::Create => Self::Create, + Action::Delete => Self::Delete, + Action::Read => Self::Read, + Action::ReadSchema => Self::ReadSchema, + Action::Write => Self::Write, + } + } +} + +/// An incompatible-permission-error is the error that is returned if +/// there is an attempt to convert a permssion into a form that is +/// unsupported. For the most part this should not cause an error to +/// be returned to the user, but more as a signal that the conversion +/// can never succeed and therefore the permisison can never be granted. +/// This error will normally be silently dropped along with the source +/// permission that caused it. +#[derive(Clone, Copy, Debug, PartialEq, Snafu)] +#[snafu(display("incompatible permission"))] +pub struct IncompatiblePermissionError {} + +/// A permission is an authorization that can be checked with an +/// authorizer. Not all authorizers neccessarily support all forms of +/// permission. If an authorizer doesn't support a permission then it +/// is not an error, the permission will always be denied. 
+#[derive(Clone, Debug, PartialEq)] +pub enum Permission { + /// ResourceAction is a permission in the form of a reasource and an + /// action. + ResourceAction(Resource, Action), +} + +impl TryFrom for Permission { + type Error = IncompatiblePermissionError; + + fn try_from(value: proto::Permission) -> Result { + match value.permission_one_of { + Some(proto::permission::PermissionOneOf::ResourceAction(ra)) => { + let r = Resource::try_from_proto( + proto::resource_action_permission::ResourceType::try_from(ra.resource_type) + .map_err(|_| IncompatiblePermissionError {})?, + ra.resource_id, + )?; + let a = Action::try_from( + proto::resource_action_permission::Action::try_from(ra.action) + .map_err(|_| IncompatiblePermissionError {})?, + )?; + Ok(Self::ResourceAction(r, a)) + } + _ => Err(IncompatiblePermissionError {}), + } + } +} + +impl TryFrom for proto::Permission { + type Error = IncompatiblePermissionError; + + fn try_from(value: Permission) -> Result { + match value { + Permission::ResourceAction(r, a) => { + let (rt, ri) = r.try_into_proto()?; + let a: proto::resource_action_permission::Action = a.into(); + Ok(Self { + permission_one_of: Some(proto::permission::PermissionOneOf::ResourceAction( + proto::ResourceActionPermission { + resource_type: rt as i32, + resource_id: ri, + action: a as i32, + }, + )), + }) + } + } + } +} + +/// A resource is the object that a request is trying to access. +#[derive(Clone, Debug, PartialEq)] +pub enum Resource { + /// A database is a named IOx database. 
+ Database(String), +} + +impl Resource { + fn try_from_proto( + rt: proto::resource_action_permission::ResourceType, + ri: Option, + ) -> Result { + match (rt, ri) { + (proto::resource_action_permission::ResourceType::Database, Some(s)) => { + Ok(Self::Database(s)) + } + _ => Err(IncompatiblePermissionError {}), + } + } + + fn try_into_proto( + self, + ) -> Result< + ( + proto::resource_action_permission::ResourceType, + Option, + ), + IncompatiblePermissionError, + > { + match self { + Self::Database(s) => Ok(( + proto::resource_action_permission::ResourceType::Database, + Some(s), + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn action_try_from_proto() { + assert_eq!( + Action::Create, + Action::try_from(proto::resource_action_permission::Action::Create).unwrap(), + ); + assert_eq!( + Action::Delete, + Action::try_from(proto::resource_action_permission::Action::Delete).unwrap(), + ); + assert_eq!( + Action::Read, + Action::try_from(proto::resource_action_permission::Action::Read).unwrap(), + ); + assert_eq!( + Action::ReadSchema, + Action::try_from(proto::resource_action_permission::Action::ReadSchema).unwrap(), + ); + assert_eq!( + Action::Write, + Action::try_from(proto::resource_action_permission::Action::Write).unwrap(), + ); + assert_eq!( + IncompatiblePermissionError {}, + Action::try_from(proto::resource_action_permission::Action::Unspecified).unwrap_err(), + ); + } + + #[test] + fn action_into_proto() { + assert_eq!( + proto::resource_action_permission::Action::Create, + proto::resource_action_permission::Action::from(Action::Create) + ); + assert_eq!( + proto::resource_action_permission::Action::Delete, + proto::resource_action_permission::Action::from(Action::Delete) + ); + assert_eq!( + proto::resource_action_permission::Action::Read, + proto::resource_action_permission::Action::from(Action::Read) + ); + assert_eq!( + proto::resource_action_permission::Action::ReadSchema, + 
proto::resource_action_permission::Action::from(Action::ReadSchema) + ); + assert_eq!( + proto::resource_action_permission::Action::Write, + proto::resource_action_permission::Action::from(Action::Write) + ); + } + + #[test] + fn resource_try_from_proto() { + assert_eq!( + Resource::Database("ns1".into()), + Resource::try_from_proto( + proto::resource_action_permission::ResourceType::Database, + Some("ns1".into()) + ) + .unwrap() + ); + assert_eq!( + IncompatiblePermissionError {}, + Resource::try_from_proto( + proto::resource_action_permission::ResourceType::Database, + None + ) + .unwrap_err() + ); + assert_eq!( + IncompatiblePermissionError {}, + Resource::try_from_proto( + proto::resource_action_permission::ResourceType::Unspecified, + Some("ns1".into()) + ) + .unwrap_err() + ); + } + + #[test] + fn resource_try_into_proto() { + assert_eq!( + ( + proto::resource_action_permission::ResourceType::Database, + Some("ns1".into()) + ), + Resource::Database("ns1".into()).try_into_proto().unwrap(), + ); + } + + #[test] + fn permission_try_from_proto() { + assert_eq!( + Permission::ResourceAction(Resource::Database("ns2".into()), Action::Create), + Permission::try_from(proto::Permission { + permission_one_of: Some(proto::permission::PermissionOneOf::ResourceAction( + proto::ResourceActionPermission { + resource_type: 1, + resource_id: Some("ns2".into()), + action: 4, + } + )) + }) + .unwrap() + ); + assert_eq!( + IncompatiblePermissionError {}, + Permission::try_from(proto::Permission { + permission_one_of: Some(proto::permission::PermissionOneOf::ResourceAction( + proto::ResourceActionPermission { + resource_type: 0, + resource_id: Some("ns2".into()), + action: 4, + } + )) + }) + .unwrap_err() + ); + assert_eq!( + IncompatiblePermissionError {}, + Permission::try_from(proto::Permission { + permission_one_of: Some(proto::permission::PermissionOneOf::ResourceAction( + proto::ResourceActionPermission { + resource_type: 1, + resource_id: Some("ns2".into()), + action: 0, + 
} + )) + }) + .unwrap_err() + ); + } + + #[test] + fn permission_try_into_proto() { + assert_eq!( + proto::Permission { + permission_one_of: Some(proto::permission::PermissionOneOf::ResourceAction( + proto::ResourceActionPermission { + resource_type: 1, + resource_id: Some("ns3".into()), + action: 4, + } + )) + }, + proto::Permission::try_from(Permission::ResourceAction( + Resource::Database("ns3".into()), + Action::Create + )) + .unwrap() + ); + } +} diff --git a/backoff/Cargo.toml b/backoff/Cargo.toml new file mode 100644 index 0000000..484412f --- /dev/null +++ b/backoff/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "backoff" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[lints] +workspace = true + +[dependencies] +tokio = { version = "1.35", features = ["macros", "time"] } +observability_deps = { path = "../observability_deps" } +rand = "0.8" +snafu = "0.8" +workspace-hack = { version = "0.1", path = "../workspace-hack" } diff --git a/backoff/src/lib.rs b/backoff/src/lib.rs new file mode 100644 index 0000000..907847b --- /dev/null +++ b/backoff/src/lib.rs @@ -0,0 +1,399 @@ +//! Backoff functionality. +#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)] +#![warn( + missing_copy_implementations, + missing_debug_implementations, + missing_docs, + clippy::explicit_iter_loop, + // See https://github.com/influxdata/influxdb_iox/pull/1671 + clippy::future_not_send, + clippy::use_self, + clippy::clone_on_ref_ptr, + clippy::todo, + clippy::dbg_macro, + unused_crate_dependencies +)] + +// Workaround for "unused crate" lint false positives. +use workspace_hack as _; + +use observability_deps::tracing::warn; +use rand::prelude::*; +use snafu::Snafu; +use std::ops::ControlFlow; +use std::time::Duration; + +/// Exponential backoff with jitter +/// +/// See +#[derive(Debug, Clone, PartialEq)] +#[allow(missing_copy_implementations)] +pub struct BackoffConfig { + /// Initial backoff. 
+ pub init_backoff: Duration, + + /// Maximum backoff. + pub max_backoff: Duration, + + /// Multiplier for each backoff round. + pub base: f64, + + /// Deadline after which we give up retrying. + pub deadline: Option, +} + +impl Default for BackoffConfig { + fn default() -> Self { + Self { + init_backoff: Duration::from_millis(100), + max_backoff: Duration::from_secs(500), + base: 3., + deadline: None, + } + } +} + +/// Error after giving up retrying. +#[derive(Debug, Snafu, PartialEq, Eq)] +#[allow(missing_copy_implementations, missing_docs)] +pub enum BackoffError +where + E: std::error::Error + 'static, +{ + #[snafu(display("Retry did not exceed within {deadline:?}: {source}"))] + DeadlineExceeded { deadline: Duration, source: E }, +} + +/// Backoff result. +pub type BackoffResult = Result>; + +/// [`Backoff`] can be created from a [`BackoffConfig`] +/// +/// Consecutive calls to [`Backoff::next`] will return the next backoff interval +/// +pub struct Backoff { + init_backoff: f64, + next_backoff_secs: f64, + max_backoff_secs: f64, + base: f64, + total: f64, + deadline: Option, + rng: Option>, +} + +impl std::fmt::Debug for Backoff { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Backoff") + .field("init_backoff", &self.init_backoff) + .field("next_backoff_secs", &self.next_backoff_secs) + .field("max_backoff_secs", &self.max_backoff_secs) + .field("base", &self.base) + .field("total", &self.total) + .field("deadline", &self.deadline) + .finish() + } +} + +impl Backoff { + /// Create a new [`Backoff`] from the provided [`BackoffConfig`]. + /// + /// # Panics + /// Panics if [`BackoffConfig::base`] is not finite or < 1.0. + pub fn new(config: &BackoffConfig) -> Self { + Self::new_with_rng(config, None) + } + + /// Creates a new `Backoff` with the optional `rng`. + /// + /// Uses [`rand::thread_rng()`] if no rng is provided. + /// + /// See [`new`](Self::new) for panic handling.
+ pub fn new_with_rng( + config: &BackoffConfig, + rng: Option>, + ) -> Self { + assert!( + config.base.is_finite(), + "Backoff base ({}) must be finite.", + config.base, + ); + assert!( + config.base >= 1.0, + "Backoff base ({}) must be greater or equal than 1.", + config.base, + ); + + let max_backoff = config.max_backoff.as_secs_f64(); + let init_backoff = config.init_backoff.as_secs_f64().min(max_backoff); + Self { + init_backoff, + next_backoff_secs: init_backoff, + max_backoff_secs: max_backoff, + base: config.base, + total: 0.0, + deadline: config.deadline.map(|d| d.as_secs_f64()), + rng, + } + } + + /// Fade this backoff over to a different backoff config. + pub fn fade_to(&mut self, config: &BackoffConfig) { + // Note: `new` won't have the same RNG, but this doesn't matter + let new = Self::new(config); + + *self = Self { + init_backoff: new.init_backoff, + next_backoff_secs: self.next_backoff_secs, + max_backoff_secs: new.max_backoff_secs, + base: new.base, + total: self.total, + deadline: new.deadline, + rng: self.rng.take(), + }; + } + + /// Perform an async operation that retries with a backoff + pub async fn retry_with_backoff( + &mut self, + task_name: &str, + mut do_stuff: F, + ) -> BackoffResult + where + F: (FnMut() -> F1) + Send, + F1: std::future::Future> + Send, + E: std::error::Error + Send + 'static, + { + let mut fail_count = 0_usize; + loop { + // first execute `F` and then use it, so we can avoid `F: Sync`. 
+ let do_stuff = do_stuff(); + + let e = match do_stuff.await { + ControlFlow::Break(r) => break Ok(r), + ControlFlow::Continue(e) => e, + }; + + let backoff = match self.next() { + Some(backoff) => backoff, + None => { + return Err(BackoffError::DeadlineExceeded { + deadline: Duration::from_secs_f64(self.deadline.expect("deadline")), + source: e, + }); + } + }; + + fail_count += 1; + + warn!( + error=%e, + task_name, + backoff_secs = backoff.as_secs(), + fail_count, + "request encountered non-fatal error - backing off", + ); + tokio::time::sleep(backoff).await; + } + } + + /// Retry all errors. + pub async fn retry_all_errors( + &mut self, + task_name: &str, + mut do_stuff: F, + ) -> BackoffResult + where + F: (FnMut() -> F1) + Send, + F1: std::future::Future> + Send, + E: std::error::Error + Send + 'static, + { + self.retry_with_backoff(task_name, move || { + // first execute `F` and then use it, so we can avoid `F: Sync`. + let do_stuff = do_stuff(); + + async { + match do_stuff.await { + Ok(b) => ControlFlow::Break(b), + Err(e) => ControlFlow::Continue(e), + } + } + }) + .await + } +} + +impl Iterator for Backoff { + type Item = Duration; + + /// Returns the next backoff duration to wait for, if any + fn next(&mut self) -> Option { + let range = self.init_backoff..=(self.next_backoff_secs * self.base); + + let rand_backoff = match self.rng.as_mut() { + Some(rng) => rng.gen_range(range), + None => thread_rng().gen_range(range), + }; + + let next_backoff = self.max_backoff_secs.min(rand_backoff); + self.total += next_backoff; + let res = std::mem::replace(&mut self.next_backoff_secs, next_backoff); + if let Some(deadline) = self.deadline { + if self.total >= deadline { + return None; + } + } + duration_try_from_secs_f64(res) + } +} + +const MAX_F64_SECS: f64 = 1_000_000.0; + +/// Try to get `Duration` from `f64` secs. +/// +/// This is required till <https://github.com/rust-lang/rust/issues/83400> is resolved.
+fn duration_try_from_secs_f64(secs: f64) -> Option { + (secs.is_finite() && (0.0..=MAX_F64_SECS).contains(&secs)) + .then(|| Duration::from_secs_f64(secs)) +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::mock::StepRng; + + #[test] + fn test_backoff() { + let init_backoff_secs = 1.; + let max_backoff_secs = 500.; + let base = 3.; + + let config = BackoffConfig { + init_backoff: Duration::from_secs_f64(init_backoff_secs), + max_backoff: Duration::from_secs_f64(max_backoff_secs), + deadline: None, + base, + }; + + let assert_fuzzy_eq = |a: f64, b: f64| assert!((b - a).abs() < 0.0001, "{a} != {b}"); + + // Create a static rng that takes the minimum of the range + let rng = Box::new(StepRng::new(0, 0)); + let mut backoff = Backoff::new_with_rng(&config, Some(rng)); + + for _ in 0..20 { + assert_eq!(backoff.next().unwrap().as_secs_f64(), init_backoff_secs); + } + + // Create a static rng that takes the maximum of the range + let rng = Box::new(StepRng::new(u64::MAX, 0)); + let mut backoff = Backoff::new_with_rng(&config, Some(rng)); + + for i in 0..20 { + let value = (base.powi(i) * init_backoff_secs).min(max_backoff_secs); + assert_fuzzy_eq(backoff.next().unwrap().as_secs_f64(), value); + } + + // Create a static rng that takes the mid point of the range + let rng = Box::new(StepRng::new(u64::MAX / 2, 0)); + let mut backoff = Backoff::new_with_rng(&config, Some(rng)); + + let mut value = init_backoff_secs; + for _ in 0..20 { + assert_fuzzy_eq(backoff.next().unwrap().as_secs_f64(), value); + value = + (init_backoff_secs + (value * base - init_backoff_secs) / 2.).min(max_backoff_secs); + } + + // deadline + let rng = Box::new(StepRng::new(u64::MAX, 0)); + let deadline = Duration::from_secs_f64(init_backoff_secs); + let mut backoff = Backoff::new_with_rng( + &BackoffConfig { + deadline: Some(deadline), + ..config + }, + Some(rng), + ); + assert_eq!(backoff.next(), None); + } + + #[test] + fn test_overflow() { + let rng = Box::new(StepRng::new(u64::MAX, 
0)); + let cfg = BackoffConfig { + init_backoff: Duration::MAX, + max_backoff: Duration::MAX, + ..Default::default() + }; + let mut backoff = Backoff::new_with_rng(&cfg, Some(rng)); + assert_eq!(backoff.next(), None); + } + + #[test] + fn test_duration_try_from_f64() { + for d in [-0.1, f64::INFINITY, f64::NAN, MAX_F64_SECS + 0.1] { + assert!(duration_try_from_secs_f64(d).is_none()); + } + + for d in [0.0, MAX_F64_SECS] { + assert!(duration_try_from_secs_f64(d).is_some()); + } + } + + #[test] + fn test_max_backoff_smaller_init() { + let rng = Box::new(StepRng::new(u64::MAX, 0)); + let cfg = BackoffConfig { + init_backoff: Duration::from_secs(2), + max_backoff: Duration::from_secs(1), + ..Default::default() + }; + let mut backoff = Backoff::new_with_rng(&cfg, Some(rng)); + assert_eq!(backoff.next(), Some(Duration::from_secs(1))); + assert_eq!(backoff.next(), Some(Duration::from_secs(1))); + } + + #[test] + #[should_panic(expected = "Backoff base (inf) must be finite.")] + fn test_panic_inf_base() { + let cfg = BackoffConfig { + base: f64::INFINITY, + ..Default::default() + }; + Backoff::new(&cfg); + } + + #[test] + #[should_panic(expected = "Backoff base (NaN) must be finite.")] + fn test_panic_nan_base() { + let cfg = BackoffConfig { + base: f64::NAN, + ..Default::default() + }; + Backoff::new(&cfg); + } + + #[test] + #[should_panic(expected = "Backoff base (0) must be greater or equal than 1.")] + fn test_panic_zero_base() { + let cfg = BackoffConfig { + base: 0.0, + ..Default::default() + }; + Backoff::new(&cfg); + } + + #[test] + fn test_constant_backoff() { + let rng = Box::new(StepRng::new(u64::MAX, 0)); + let cfg = BackoffConfig { + init_backoff: Duration::from_secs(1), + max_backoff: Duration::from_secs(1), + base: 1.0, + ..Default::default() + }; + let mut backoff = Backoff::new_with_rng(&cfg, Some(rng)); + assert_eq!(backoff.next(), Some(Duration::from_secs(1))); + assert_eq!(backoff.next(), Some(Duration::from_secs(1))); + } +} diff --git 
a/cache_system/Cargo.toml b/cache_system/Cargo.toml new file mode 100644 index 0000000..bb07eba --- /dev/null +++ b/cache_system/Cargo.toml @@ -0,0 +1,40 @@ +[package] +name = "cache_system" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[lints] +workspace = true + +[dependencies] +async-trait = "0.1.77" +backoff = { path = "../backoff" } +futures = "0.3" +iox_time = { path = "../iox_time" } +metric = { path = "../metric" } +observability_deps = { path = "../observability_deps" } +ouroboros = "0.18" +parking_lot = { version = "0.12", features = ["arc_lock"] } +pdatastructs = { version = "0.7", default-features = false, features = ["fixedbitset"] } +rand = "0.8.3" +tokio = { version = "1.35", features = ["macros", "parking_lot", "rt-multi-thread", "sync", "time"] } +tokio-util = { version = "0.7.10" } +trace = { path = "../trace"} +tracker = { path = "../tracker"} +workspace-hack = { version = "0.1", path = "../workspace-hack" } + +[dev-dependencies] +criterion = { version = "0.5", default-features = false, features = ["rayon"]} +proptest = { version = "1", default_features = false, features = ["std"] } +test_helpers = { path = "../test_helpers" } + +[lib] +# Allow --save-baseline to work +# https://github.com/bheisler/criterion.rs/issues/275 +bench = false + +[[bench]] +name = "addressable_heap" +harness = false diff --git a/cache_system/benches/addressable_heap.rs b/cache_system/benches/addressable_heap.rs new file mode 100644 index 0000000..42a9e8b --- /dev/null +++ b/cache_system/benches/addressable_heap.rs @@ -0,0 +1,420 @@ +use std::mem::size_of; + +use cache_system::addressable_heap::AddressableHeap; +use criterion::{ + criterion_group, criterion_main, measurement::WallTime, AxisScale, BatchSize, BenchmarkGroup, + BenchmarkId, Criterion, PlotConfiguration, SamplingMode, +}; +use rand::{prelude::SliceRandom, thread_rng, Rng}; + +/// Payload (`V`) for testing. 
+/// +/// This is a 64bit-wide object which is enough to store a [`Box`] or a [`usize`]. +#[derive(Debug, Clone, Default)] +struct Payload([u8; 8]); + +const _: () = assert!(size_of::() == 8); +const _: () = assert!(size_of::() >= size_of::>>()); +const _: () = assert!(size_of::() >= size_of::()); + +type TestHeap = AddressableHeap; + +const TEST_SIZES: &[usize] = &[0, 1, 10, 100, 1_000, 10_000]; + +#[derive(Debug, Clone)] +struct Entry { + k: u64, + o: u64, +} + +impl Entry { + fn new_random(rng: &mut R) -> Self + where + R: Rng, + { + Self { + // leave some room at the top and bottom + k: (rng.gen::() << 1) + (u64::MAX << 2), + // leave some room at the top and bottom + o: (rng.gen::() << 1) + (u64::MAX << 2), + } + } + + fn new_random_n(rng: &mut R, n: usize) -> Vec + where + R: Rng, + { + (0..n).map(|_| Self::new_random(rng)).collect() + } +} + +fn create_filled_heap(rng: &mut R, n: usize) -> (TestHeap, Vec) +where + R: Rng, +{ + let mut heap = TestHeap::default(); + let mut entries = Vec::with_capacity(n); + + for _ in 0..n { + let entry = Entry::new_random(rng); + heap.insert(entry.k, Payload::default(), entry.o); + entries.push(entry); + } + + (heap, entries) +} + +fn setup_group(g: &mut BenchmarkGroup<'_, WallTime>) { + g.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); + g.sampling_mode(SamplingMode::Flat); +} + +fn bench_insert_n_elements(c: &mut Criterion) { + let mut g = c.benchmark_group("insert_n_elements"); + setup_group(&mut g); + + let mut rng = thread_rng(); + + for n in TEST_SIZES { + g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| { + b.iter_batched( + || (TestHeap::default(), Entry::new_random_n(&mut rng, *n)), + |(mut heap, entries)| { + for entry in &entries { + heap.insert(entry.k, Payload::default(), entry.o); + } + + // let criterion handle the drop + (heap, entries) + }, + BatchSize::LargeInput, + ); + }); + } + + g.finish(); +} + +fn bench_peek_after_n_elements(c: &mut Criterion) { + let 
mut g = c.benchmark_group("peek_after_n_elements"); + setup_group(&mut g); + + let mut rng = thread_rng(); + + for n in TEST_SIZES { + g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| { + b.iter_batched( + || create_filled_heap(&mut rng, *n).0, + |heap| { + heap.peek(); + + // let criterion handle the drop + heap + }, + BatchSize::LargeInput, + ); + }); + } + + g.finish(); +} + +fn bench_get_existing_after_n_elements(c: &mut Criterion) { + let mut g = c.benchmark_group("get_existing_after_n_elements"); + setup_group(&mut g); + + let mut rng = thread_rng(); + + for n in TEST_SIZES { + if *n == 0 { + continue; + } + + g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| { + b.iter_batched( + || { + let (heap, entries) = create_filled_heap(&mut rng, *n); + let entry = entries.choose(&mut rng).unwrap().clone(); + (heap, entry) + }, + |(heap, entry)| { + heap.get(&entry.k); + + // let criterion handle the drop + heap + }, + BatchSize::LargeInput, + ); + }); + } + + g.finish(); +} + +fn bench_get_new_after_n_elements(c: &mut Criterion) { + let mut g = c.benchmark_group("get_new_after_n_elements"); + setup_group(&mut g); + + let mut rng = thread_rng(); + + for n in TEST_SIZES { + g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| { + b.iter_batched( + || { + let (heap, _entries) = create_filled_heap(&mut rng, *n); + let entry = Entry::new_random(&mut rng); + (heap, entry) + }, + |(heap, entry)| { + heap.get(&entry.k); + + // let criterion handle the drop + heap + }, + BatchSize::LargeInput, + ); + }); + } + + g.finish(); +} + +fn bench_pop_n_elements(c: &mut Criterion) { + let mut g = c.benchmark_group("pop_n_elements"); + setup_group(&mut g); + + let mut rng = thread_rng(); + + for n in TEST_SIZES { + g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| { + b.iter_batched( + || create_filled_heap(&mut rng, *n).0, + |mut heap| { + for _ in 0..*n { + heap.pop(); + } + + // let criterion handle the drop + heap + }, + 
BatchSize::LargeInput, + ); + }); + } + + g.finish(); +} + +fn bench_remove_existing_after_n_elements(c: &mut Criterion) { + let mut g = c.benchmark_group("remove_existing_after_n_elements"); + setup_group(&mut g); + + let mut rng = thread_rng(); + + for n in TEST_SIZES { + if *n == 0 { + continue; + } + + g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| { + b.iter_batched( + || { + let (heap, entries) = create_filled_heap(&mut rng, *n); + let entry = entries.choose(&mut rng).unwrap().clone(); + (heap, entry) + }, + |(mut heap, entry)| { + heap.remove(&entry.k); + + // let criterion handle the drop + heap + }, + BatchSize::LargeInput, + ); + }); + } + + g.finish(); +} + +fn bench_remove_new_after_n_elements(c: &mut Criterion) { + let mut g = c.benchmark_group("remove_new_after_n_elements"); + setup_group(&mut g); + + let mut rng = thread_rng(); + + for n in TEST_SIZES { + g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| { + b.iter_batched( + || { + let (heap, _entries) = create_filled_heap(&mut rng, *n); + let entry = Entry::new_random(&mut rng); + (heap, entry) + }, + |(mut heap, entry)| { + heap.remove(&entry.k); + + // let criterion handle the drop + heap + }, + BatchSize::LargeInput, + ); + }); + } + + g.finish(); +} + +fn bench_replace_after_n_elements(c: &mut Criterion) { + let mut g = c.benchmark_group("replace_after_n_elements"); + setup_group(&mut g); + + let mut rng = thread_rng(); + + for n in TEST_SIZES { + if *n == 0 { + continue; + } + + g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| { + b.iter_batched( + || { + let (heap, entries) = create_filled_heap(&mut rng, *n); + let entry = entries.choose(&mut rng).unwrap().clone(); + let entry = Entry { + k: entry.k, + o: Entry::new_random(&mut rng).o, + }; + (heap, entry) + }, + |(mut heap, entry)| { + heap.insert(entry.k, Payload::default(), entry.o); + + // let criterion handle the drop + heap + }, + BatchSize::LargeInput, + ); + }); + } + + g.finish(); +} + 
+fn bench_update_order_existing_to_random_after_n_elements(c: &mut Criterion) { + let mut g = c.benchmark_group("update_order_existing_to_random_after_n_elements"); + setup_group(&mut g); + + let mut rng = thread_rng(); + + for n in TEST_SIZES { + if *n == 0 { + continue; + } + + g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| { + b.iter_batched( + || { + let (heap, entries) = create_filled_heap(&mut rng, *n); + let entry = entries.choose(&mut rng).unwrap().clone(); + let entry = Entry { + k: entry.k, + o: Entry::new_random(&mut rng).o, + }; + (heap, entry) + }, + |(mut heap, entry)| { + heap.update_order(&entry.k, entry.o); + + // let criterion handle the drop + heap + }, + BatchSize::LargeInput, + ); + }); + } + + g.finish(); +} + +fn bench_update_order_existing_to_last_after_n_elements(c: &mut Criterion) { + let mut g = c.benchmark_group("update_order_existing_to_first_after_n_elements"); + setup_group(&mut g); + + let mut rng = thread_rng(); + + for n in TEST_SIZES { + if *n == 0 { + continue; + } + + g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| { + b.iter_batched( + || { + let (heap, entries) = create_filled_heap(&mut rng, *n); + let entry = entries.choose(&mut rng).unwrap().clone(); + let entry = Entry { + k: entry.k, + o: u64::MAX - (u64::MAX << 2), + }; + (heap, entry) + }, + |(mut heap, entry)| { + heap.update_order(&entry.k, entry.o); + + // let criterion handle the drop + heap + }, + BatchSize::LargeInput, + ); + }); + } + + g.finish(); +} + +fn bench_update_order_new_after_n_elements(c: &mut Criterion) { + let mut g = c.benchmark_group("update_order_new_after_n_elements"); + setup_group(&mut g); + + let mut rng = thread_rng(); + + for n in TEST_SIZES { + g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| { + b.iter_batched( + || { + let (heap, _entries) = create_filled_heap(&mut rng, *n); + let entry = Entry::new_random(&mut rng); + (heap, entry) + }, + |(mut heap, entry)| { + heap.update_order(&entry.k, 
entry.o); + + // let criterion handle the drop + heap + }, + BatchSize::LargeInput, + ); + }); + } + + g.finish(); +} + +criterion_group! { + name = benches; + config = Criterion::default(); + targets = + bench_insert_n_elements, + bench_peek_after_n_elements, + bench_get_existing_after_n_elements, + bench_get_new_after_n_elements, + bench_pop_n_elements, + bench_remove_existing_after_n_elements, + bench_remove_new_after_n_elements, + bench_replace_after_n_elements, + bench_update_order_existing_to_random_after_n_elements, + bench_update_order_existing_to_last_after_n_elements, + bench_update_order_new_after_n_elements, +} +criterion_main!(benches); diff --git a/cache_system/src/addressable_heap.rs b/cache_system/src/addressable_heap.rs new file mode 100644 index 0000000..4f3466f --- /dev/null +++ b/cache_system/src/addressable_heap.rs @@ -0,0 +1,611 @@ +//! Implementation of an [`AddressableHeap`]. +use std::{ + collections::{hash_map, BTreeSet, HashMap}, + hash::Hash, +}; + +/// Addressable heap. +/// +/// Stores a value `V` together with a key `K` and an order `O`. Elements are sorted by `O` and the smallest element can +/// be peeked/popped. At the same time elements can be addressed via `K`. +/// +/// Note that `K` requires the inner data structure to implement [`Ord`] as a tie breaker. +#[derive(Debug, Clone)] +pub struct AddressableHeap +where + K: Clone + Eq + Hash + Ord, + O: Clone + Ord, +{ + /// Key to order and value. + /// + /// The order is required to lookup data within the queue. + /// + /// The value is stored here instead of the queue since HashMap entries are copied around less often than queue elements. + key_to_order_and_value: HashMap, + + /// Queue that handles the priorities. + /// + /// The order goes first, the key goes second. + /// + /// Note: This is not really a heap, but it fulfills the interface that we need. 
+ queue: BTreeSet<(O, K)>, +} + +impl AddressableHeap +where + K: Clone + Eq + Hash + Ord, + O: Clone + Ord, +{ + /// Create new, empty heap. + pub fn new() -> Self { + Self { + key_to_order_and_value: HashMap::new(), + queue: BTreeSet::new(), + } + } + + /// Check if the heap is empty. + pub fn is_empty(&self) -> bool { + let res1 = self.key_to_order_and_value.is_empty(); + let res2 = self.queue.is_empty(); + assert_eq!(res1, res2, "data structures out of sync"); + res1 + } + + /// Insert element. + /// + /// If the element (compared by `K`) already exists, it will be returned. + pub fn insert(&mut self, k: K, v: V, o: O) -> Option<(V, O)> { + let (result, k) = match self.key_to_order_and_value.entry(k.clone()) { + hash_map::Entry::Occupied(mut entry_o) => { + // `entry_o.replace_entry(...)` is not stable yet, see https://github.com/rust-lang/rust/issues/44286 + let mut tmp = (v, o.clone()); + std::mem::swap(&mut tmp, entry_o.get_mut()); + let (v_old, o_old) = tmp; + + let query = (o_old, k); + let existed = self.queue.remove(&query); + assert!(existed, "key was in key_to_order"); + let (o_old, k) = query; + + (Some((v_old, o_old)), k) + } + hash_map::Entry::Vacant(entry_v) => { + entry_v.insert((v, o.clone())); + (None, k) + } + }; + + let inserted = self.queue.insert((o, k)); + assert!(inserted, "entry should have been removed by now"); + + result + } + + /// Peek first element (by smallest `O`). + pub fn peek(&self) -> Option<(&K, &V, &O)> { + self.iter().next() + } + + /// Pop first element (by smallest `O`) from heap. + pub fn pop(&mut self) -> Option<(K, V, O)> { + if let Some((o, k)) = self.queue.pop_first() { + let (v, o2) = self + .key_to_order_and_value + .remove(&k) + .expect("value is in queue"); + assert!(o == o2); + Some((k, v, o)) + } else { + None + } + } + + /// Iterate over elements in order of `O` (starting at smallest). + /// + /// This is equivalent to calling [`pop`](Self::pop) multiple times, but does NOT modify the collection.
+ pub fn iter(&self) -> AddressableHeapIter<'_, K, V, O> { + AddressableHeapIter { + key_to_order_and_value: &self.key_to_order_and_value, + queue_iter: self.queue.iter(), + } + } + + /// Get element by key. + pub fn get(&self, k: &K) -> Option<(&V, &O)> { + self.key_to_order_and_value.get(k).map(project_tuple) + } + + /// Remove element by key. + /// + /// If the element exists within the heap (addressed via `K`), the value and order will be returned. + pub fn remove(&mut self, k: &K) -> Option<(V, O)> { + if let Some((k, (v, o))) = self.key_to_order_and_value.remove_entry(k) { + let query = (o, k); + let existed = self.queue.remove(&query); + assert!(existed, "key was in key_to_order"); + let (o, _k) = query; + Some((v, o)) + } else { + None + } + } + + /// Update order of a given key. + /// + /// Returns existing order if the key existed. + pub fn update_order(&mut self, k: &K, o: O) -> Option { + match self.key_to_order_and_value.get_mut(k) { + Some(entry) => { + let mut o_old = o.clone(); + std::mem::swap(&mut entry.1, &mut o_old); + + let query = (o_old, k.clone()); + let existed = self.queue.remove(&query); + assert!(existed, "key was in key_to_order"); + let (o_old, k) = query; + + let inserted = self.queue.insert((o, k)); + assert!(inserted, "entry should have been removed by now"); + + Some(o_old) + } + None => None, + } + } +} + +impl Default for AddressableHeap +where + K: Clone + Eq + Hash + Ord, + O: Clone + Ord, +{ + fn default() -> Self { + Self::new() + } +} + +/// Project tuple references. +fn project_tuple(t: &(A, B)) -> (&A, &B) { + (&t.0, &t.1) +} + +/// Iterator of [`AddressableHeap::iter`]. 
+#[derive(Debug)] +pub struct AddressableHeapIter<'a, K, V, O> +where + K: Clone + Eq + Hash + Ord, + O: Clone + Ord, +{ + key_to_order_and_value: &'a HashMap, + queue_iter: std::collections::btree_set::Iter<'a, (O, K)>, +} + +impl<'a, K, V, O> Iterator for AddressableHeapIter<'a, K, V, O> +where + K: Clone + Eq + Hash + Ord, + O: Clone + Ord, +{ + type Item = (&'a K, &'a V, &'a O); + + fn next(&mut self) -> Option { + self.queue_iter.next().map(|(o, k)| { + let (v, o2) = self + .key_to_order_and_value + .get(k) + .expect("value is in queue"); + assert!(o == o2); + (k, v, o) + }) + } + + fn size_hint(&self) -> (usize, Option) { + self.queue_iter.size_hint() + } +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::*; + + #[test] + fn test_peek_empty() { + let heap = AddressableHeap::::new(); + + assert_eq!(heap.peek(), None); + } + + #[test] + fn test_peek_some() { + let mut heap = AddressableHeap::new(); + + heap.insert(1, "a", 4); + heap.insert(2, "b", 3); + heap.insert(3, "c", 5); + + assert_eq!(heap.peek(), Some((&2, &"b", &3))); + } + + #[test] + fn test_peek_tie() { + let mut heap = AddressableHeap::new(); + + heap.insert(3, "a", 1); + heap.insert(1, "b", 1); + heap.insert(2, "c", 1); + + assert_eq!(heap.peek(), Some((&1, &"b", &1))); + } + + #[test] + fn test_peek_after_remove() { + let mut heap = AddressableHeap::new(); + + heap.insert(1, "a", 4); + heap.insert(2, "b", 3); + heap.insert(3, "c", 5); + + assert_eq!(heap.peek(), Some((&2, &"b", &3))); + heap.remove(&3); + assert_eq!(heap.peek(), Some((&2, &"b", &3))); + heap.remove(&2); + assert_eq!(heap.peek(), Some((&1, &"a", &4))); + heap.remove(&1); + assert_eq!(heap.peek(), None); + } + + #[test] + fn test_peek_after_override() { + let mut heap = AddressableHeap::new(); + + heap.insert(1, "a", 4); + heap.insert(2, "b", 3); + heap.insert(1, "c", 2); + + assert_eq!(heap.peek(), Some((&1, &"c", &2))); + } + + #[test] + fn test_pop_empty() { + let mut heap = AddressableHeap::::new(); + + 
assert_eq!(heap.pop(), None); + } + + #[test] + fn test_pop_all() { + let mut heap = AddressableHeap::new(); + + heap.insert(1, "a", 4); + heap.insert(2, "b", 3); + heap.insert(3, "c", 5); + + assert_eq!(heap.pop(), Some((2, "b", 3))); + assert_eq!(heap.pop(), Some((1, "a", 4))); + assert_eq!(heap.pop(), Some((3, "c", 5))); + assert_eq!(heap.pop(), None); + } + + #[test] + fn test_pop_tie() { + let mut heap = AddressableHeap::new(); + + heap.insert(3, "a", 1); + heap.insert(1, "b", 1); + heap.insert(2, "c", 1); + + assert_eq!(heap.pop(), Some((1, "b", 1))); + assert_eq!(heap.pop(), Some((2, "c", 1))); + assert_eq!(heap.pop(), Some((3, "a", 1))); + assert_eq!(heap.pop(), None); + } + + #[test] + fn test_pop_after_insert() { + let mut heap = AddressableHeap::new(); + + heap.insert(1, "a", 4); + heap.insert(2, "b", 3); + heap.insert(3, "c", 5); + + assert_eq!(heap.pop(), Some((2, "b", 3))); + + heap.insert(4, "d", 2); + assert_eq!(heap.pop(), Some((4, "d", 2))); + assert_eq!(heap.pop(), Some((1, "a", 4))); + } + + #[test] + fn test_pop_after_remove() { + let mut heap = AddressableHeap::new(); + + heap.insert(1, "a", 4); + heap.insert(2, "b", 3); + heap.insert(3, "c", 5); + + heap.remove(&2); + assert_eq!(heap.pop(), Some((1, "a", 4))); + } + + #[test] + fn test_pop_after_override() { + let mut heap = AddressableHeap::new(); + + heap.insert(1, "a", 4); + heap.insert(2, "b", 3); + heap.insert(1, "c", 2); + + assert_eq!(heap.pop(), Some((1, "c", 2))); + assert_eq!(heap.pop(), Some((2, "b", 3))); + assert_eq!(heap.pop(), None); + } + + #[test] + fn test_get_empty() { + let heap = AddressableHeap::::new(); + + assert_eq!(heap.get(&1), None); + } + + #[test] + fn test_get_multiple() { + let mut heap = AddressableHeap::new(); + + heap.insert(1, "a", 4); + heap.insert(2, "b", 3); + + assert_eq!(heap.get(&1), Some((&"a", &4))); + assert_eq!(heap.get(&2), Some((&"b", &3))); + } + + #[test] + fn test_get_after_remove() { + let mut heap = AddressableHeap::new(); + + 
heap.insert(1, "a", 4); + heap.insert(2, "b", 3); + + heap.remove(&1); + + assert_eq!(heap.get(&1), None); + assert_eq!(heap.get(&2), Some((&"b", &3))); + } + + #[test] + fn test_get_after_pop() { + let mut heap = AddressableHeap::new(); + + heap.insert(1, "a", 4); + heap.insert(2, "b", 3); + + heap.pop(); + + assert_eq!(heap.get(&1), Some((&"a", &4))); + assert_eq!(heap.get(&2), None); + } + + #[test] + fn test_get_after_override() { + let mut heap = AddressableHeap::new(); + + heap.insert(1, "a", 4); + heap.insert(1, "b", 3); + + assert_eq!(heap.get(&1), Some((&"b", &3))); + } + + #[test] + fn test_remove_empty() { + let mut heap = AddressableHeap::::new(); + + assert_eq!(heap.remove(&1), None); + } + + #[test] + fn test_remove_some() { + let mut heap = AddressableHeap::new(); + + heap.insert(1, "a", 4); + heap.insert(2, "b", 3); + + assert_eq!(heap.remove(&1), Some(("a", 4))); + assert_eq!(heap.remove(&2), Some(("b", 3))); + } + + #[test] + fn test_remove_twice() { + let mut heap = AddressableHeap::new(); + + heap.insert(1, "a", 4); + + assert_eq!(heap.remove(&1), Some(("a", 4))); + assert_eq!(heap.remove(&1), None); + } + + #[test] + fn test_remove_after_pop() { + let mut heap = AddressableHeap::new(); + + heap.insert(1, "a", 4); + heap.insert(2, "b", 3); + + heap.pop(); + + assert_eq!(heap.remove(&1), Some(("a", 4))); + assert_eq!(heap.remove(&2), None); + } + + #[test] + fn test_remove_after_override() { + let mut heap = AddressableHeap::new(); + + heap.insert(1, "a", 4); + heap.insert(1, "b", 3); + + assert_eq!(heap.remove(&1), Some(("b", 3))); + assert_eq!(heap.remove(&1), None); + } + + #[test] + fn test_override() { + let mut heap = AddressableHeap::new(); + + assert_eq!(heap.insert(1, "a", 4), None); + assert_eq!(heap.insert(2, "b", 3), None); + assert_eq!(heap.insert(1, "c", 5), Some(("a", 4))); + } + + /// Simple version of [`AddressableHeap`] for testing. 
+ struct SimpleAddressableHeap { + inner: Vec<(u8, String, i8)>, + } + + impl SimpleAddressableHeap { + fn new() -> Self { + Self { inner: Vec::new() } + } + + fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + fn insert(&mut self, k: u8, v: String, o: i8) -> Option<(String, i8)> { + let res = self.remove(&k); + self.inner.push((k, v, o)); + + res + } + + #[allow(clippy::map_identity)] // https://github.com/rust-lang/rust-clippy/issues/11764 + fn peek(&self) -> Option<(&u8, &String, &i8)> { + self.inner + .iter() + .min_by_key(|(k, _v, o)| (o, k)) + .map(|(k, v, o)| (k, v, o)) + } + + fn dump_ordered(&self) -> Vec<(u8, String, i8)> { + let mut inner = self.inner.clone(); + inner.sort_by_key(|(k, _v, o)| (*o, *k)); + inner + } + + fn pop(&mut self) -> Option<(u8, String, i8)> { + self.inner + .iter() + .enumerate() + .min_by_key(|(_idx, (k, _v, o))| (o, k)) + .map(|(idx, _)| idx) + .map(|idx| self.inner.remove(idx)) + } + + fn get(&self, k: &u8) -> Option<(&String, &i8)> { + self.inner + .iter() + .find(|(k2, _v, _o)| k2 == k) + .map(|(_k, v, o)| (v, o)) + } + + fn remove(&mut self, k: &u8) -> Option<(String, i8)> { + self.inner + .iter() + .enumerate() + .find(|(_idx, (k2, _v, _o))| k2 == k) + .map(|(idx, _)| idx) + .map(|idx| { + let (_k, v, o) = self.inner.remove(idx); + (v, o) + }) + } + + fn update_order(&mut self, k: &u8, o: i8) -> Option { + if let Some((v, o_old)) = self.remove(k) { + self.insert(*k, v, o); + Some(o_old) + } else { + None + } + } + } + + #[derive(Debug, Clone)] + enum Action { + IsEmpty, + Insert { k: u8, v: String, o: i8 }, + Peek, + Iter, + Pop, + Get { k: u8 }, + Remove { k: u8 }, + UpdateOrder { k: u8, o: i8 }, + } + + // Use a hand-rolled strategy instead of `proptest-derive`, because the latter one is quite a heavy dependency. 
+ fn action() -> impl Strategy { + prop_oneof![ + Just(Action::IsEmpty), + (any::(), ".*", any::()).prop_map(|(k, v, o)| Action::Insert { k, v, o }), + Just(Action::Peek), + Just(Action::Iter), + Just(Action::Pop), + any::().prop_map(|k| Action::Get { k }), + any::().prop_map(|k| Action::Remove { k }), + (any::(), any::()).prop_map(|(k, o)| Action::UpdateOrder { k, o }), + ] + } + + proptest! { + #[test] + fn test_proptest(actions in prop::collection::vec(action(), 0..100)) { + let mut heap = AddressableHeap::new(); + let mut sim = SimpleAddressableHeap::new(); + + for action in actions { + match action { + Action::IsEmpty => { + let res1 = heap.is_empty(); + let res2 = sim.is_empty(); + assert_eq!(res1, res2); + } + Action::Insert{k, v, o} => { + let res1 = heap.insert(k, v.clone(), o); + let res2 = sim.insert(k, v, o); + assert_eq!(res1, res2); + } + Action::Peek => { + let res1 = heap.peek(); + let res2 = sim.peek(); + assert_eq!(res1, res2); + } + Action::Iter => { + let res1 = heap.iter().map(|(k, v, o)| (*k, v.clone(), *o)).collect::>(); + let res2 = sim.dump_ordered(); + assert_eq!(res1, res2); + } + Action::Pop => { + let res1 = heap.pop(); + let res2 = sim.pop(); + assert_eq!(res1, res2); + } + Action::Get{k} => { + let res1 = heap.get(&k); + let res2 = sim.get(&k); + assert_eq!(res1, res2); + } + Action::Remove{k} => { + let res1 = heap.remove(&k); + let res2 = sim.remove(&k); + assert_eq!(res1, res2); + } + Action::UpdateOrder{k, o} => { + let res1 = heap.update_order(&k, o); + let res2 = sim.update_order(&k, o); + assert_eq!(res1, res2); + } + } + } + } + } +} diff --git a/cache_system/src/backend/hash_map.rs b/cache_system/src/backend/hash_map.rs new file mode 100644 index 0000000..cb3c302 --- /dev/null +++ b/cache_system/src/backend/hash_map.rs @@ -0,0 +1,51 @@ +//! Implements [`CacheBackend`] for [`HashMap`]. 
+use std::{ + any::Any, + collections::HashMap, + fmt::Debug, + hash::{BuildHasher, Hash}, +}; + +use super::CacheBackend; + +impl CacheBackend for HashMap +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, + S: BuildHasher + Send + 'static, +{ + type K = K; + type V = V; + + fn get(&mut self, k: &Self::K) -> Option { + Self::get(self, k).cloned() + } + + fn set(&mut self, k: Self::K, v: Self::V) { + self.insert(k, v); + } + + fn remove(&mut self, k: &Self::K) { + self.remove(k); + } + + fn is_empty(&self) -> bool { + self.is_empty() + } + + fn as_any(&self) -> &dyn Any { + self as &dyn Any + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generic() { + use crate::backend::test_util::test_generic; + + test_generic(HashMap::new); + } +} diff --git a/cache_system/src/backend/mod.rs b/cache_system/src/backend/mod.rs new file mode 100644 index 0000000..8395c83 --- /dev/null +++ b/cache_system/src/backend/mod.rs @@ -0,0 +1,66 @@ +//! Storage backends to keep and manage cached entries. +use std::{any::Any, fmt::Debug, hash::Hash}; + +pub mod hash_map; +pub mod policy; + +#[cfg(test)] +mod test_util; + +/// Backend to keep and manage stored entries. +/// +/// A backend might remove entries at any point, e.g. due to memory pressure or expiration. +pub trait CacheBackend: Debug { + /// Cache key. + type K: Clone + Eq + Hash + Ord + Debug + Send + 'static; + + /// Cached value. + type V: Clone + Debug + Send + 'static; + + /// Get value for given key if it exists. + fn get(&mut self, k: &Self::K) -> Option; + + /// Set value for given key. + /// + /// It is OK to set and override a key that already exists. + fn set(&mut self, k: Self::K, v: Self::V); + + /// Remove value for given key. + /// + /// It is OK to remove a key even when it does not exist. + fn remove(&mut self, k: &Self::K); + + /// Check if backend is empty. 
+ fn is_empty(&self) -> bool; + + /// Return backend as [`Any`] which can be used to downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; +} + +impl CacheBackend for Box +where + T: CacheBackend + ?Sized + 'static, +{ + type K = T::K; + type V = T::V; + + fn get(&mut self, k: &Self::K) -> Option { + self.as_mut().get(k) + } + + fn set(&mut self, k: Self::K, v: Self::V) { + self.as_mut().set(k, v) + } + + fn remove(&mut self, k: &Self::K) { + self.as_mut().remove(k) + } + + fn is_empty(&self) -> bool { + self.as_ref().is_empty() + } + + fn as_any(&self) -> &dyn Any { + self as &dyn Any + } +} diff --git a/cache_system/src/backend/policy/integration_tests.rs b/cache_system/src/backend/policy/integration_tests.rs new file mode 100644 index 0000000..c99a2d0 --- /dev/null +++ b/cache_system/src/backend/policy/integration_tests.rs @@ -0,0 +1,599 @@ +//! Test integration between different policies. + +use std::{collections::HashMap, sync::Arc, time::Duration}; + +use iox_time::{MockProvider, Time}; +use parking_lot::Mutex; +use rand::rngs::mock::StepRng; +use test_helpers::maybe_start_logging; +use tokio::{runtime::Handle, sync::Notify}; + +use crate::{ + backend::{ + policy::refresh::test_util::{backoff_cfg, NotifyExt}, + CacheBackend, + }, + loader::test_util::TestLoader, + resource_consumption::{test_util::TestSize, ResourceEstimator}, +}; + +use super::{ + lru::{LruPolicy, ResourcePool}, + refresh::{test_util::TestRefreshDurationProvider, RefreshPolicy}, + remove_if::{RemoveIfHandle, RemoveIfPolicy}, + ttl::{test_util::TestTtlProvider, TtlPolicy}, + PolicyBackend, +}; + +#[tokio::test] +async fn test_refresh_can_prevent_expiration() { + let TestStateTtlAndRefresh { + mut backend, + refresh_duration_provider, + ttl_provider, + time_provider, + loader, + notify_idle, + .. 
+ } = TestStateTtlAndRefresh::new(); + + loader.mock_next(1, String::from("foo")); + + refresh_duration_provider.set_refresh_in( + 1, + String::from("a"), + Some(backoff_cfg(Duration::from_secs(1))), + ); + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(2))); + + refresh_duration_provider.set_refresh_in(1, String::from("foo"), None); + ttl_provider.set_expires_in(1, String::from("foo"), Some(Duration::from_secs(2))); + + backend.set(1, String::from("a")); + + // perform refresh + time_provider.inc(Duration::from_secs(1)); + notify_idle.notified_with_timeout().await; + + // no expired because refresh resets the timer + time_provider.inc(Duration::from_secs(1)); + assert_eq!(backend.get(&1), Some(String::from("foo"))); + + // we don't request a 2nd refresh (refresh duration is None), so this finally expires + time_provider.inc(Duration::from_secs(1)); + assert_eq!(backend.get(&1), None); +} + +#[tokio::test] +async fn test_refresh_sets_new_expiration_after_it_finishes() { + let TestStateTtlAndRefresh { + mut backend, + refresh_duration_provider, + ttl_provider, + time_provider, + loader, + notify_idle, + .. 
+ } = TestStateTtlAndRefresh::new(); + + let barrier = loader.block_next(1, String::from("foo")); + + refresh_duration_provider.set_refresh_in( + 1, + String::from("a"), + Some(backoff_cfg(Duration::from_secs(1))), + ); + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(3))); + + refresh_duration_provider.set_refresh_in(1, String::from("foo"), None); + ttl_provider.set_expires_in(1, String::from("foo"), Some(Duration::from_secs(3))); + + backend.set(1, String::from("a")); + + // perform refresh + time_provider.inc(Duration::from_secs(1)); + notify_idle.notified_with_timeout().await; + + time_provider.inc(Duration::from_secs(1)); + barrier.wait().await; + notify_idle.notified_with_timeout().await; + assert_eq!(backend.get(&1), Some(String::from("foo"))); + + // no expired because refresh resets the timer after it was ready (now), not when it started (1s ago) + time_provider.inc(Duration::from_secs(2)); + assert_eq!(backend.get(&1), Some(String::from("foo"))); + + // we don't request a 2nd refresh (refresh duration is None), so this finally expires + time_provider.inc(Duration::from_secs(1)); + assert_eq!(backend.get(&1), None); +} + +#[tokio::test] +async fn test_refresh_does_not_update_lru_time() { + let TestStateLruAndRefresh { + mut backend, + refresh_duration_provider, + size_estimator, + time_provider, + loader, + notify_idle, + pool, + .. 
+ } = TestStateLruAndRefresh::new(); + + size_estimator.mock_size(1, String::from("a"), TestSize(4)); + size_estimator.mock_size(1, String::from("foo"), TestSize(4)); + size_estimator.mock_size(2, String::from("b"), TestSize(4)); + size_estimator.mock_size(3, String::from("c"), TestSize(4)); + + refresh_duration_provider.set_refresh_in( + 1, + String::from("a"), + Some(backoff_cfg(Duration::from_secs(1))), + ); + refresh_duration_provider.set_refresh_in(1, String::from("foo"), None); + refresh_duration_provider.set_refresh_in(2, String::from("b"), None); + refresh_duration_provider.set_refresh_in(3, String::from("c"), None); + + let barrier = loader.block_next(1, String::from("foo")); + backend.set(1, String::from("a")); + pool.wait_converged().await; + + // trigger refresh + time_provider.inc(Duration::from_secs(1)); + + time_provider.inc(Duration::from_secs(1)); + backend.set(2, String::from("b")); + pool.wait_converged().await; + + time_provider.inc(Duration::from_secs(1)); + + notify_idle.notified_with_timeout().await; + barrier.wait().await; + notify_idle.notified_with_timeout().await; + + // add a third item to the cache, forcing LRU to evict one of the items + backend.set(3, String::from("c")); + pool.wait_converged().await; + + // Should evict `1` even though it was refreshed after `2` was added + assert_eq!(backend.get(&1), None); +} + +#[tokio::test] +async fn test_if_refresh_to_slow_then_expire() { + let TestStateTtlAndRefresh { + mut backend, + refresh_duration_provider, + ttl_provider, + time_provider, + loader, + notify_idle, + .. 
+ } = TestStateTtlAndRefresh::new(); + + let barrier = loader.block_next(1, String::from("foo")); + refresh_duration_provider.set_refresh_in( + 1, + String::from("a"), + Some(backoff_cfg(Duration::from_secs(1))), + ); + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(2))); + backend.set(1, String::from("a")); + + // perform refresh + time_provider.inc(Duration::from_secs(1)); + notify_idle.notified_with_timeout().await; + + time_provider.inc(Duration::from_secs(1)); + notify_idle.not_notified().await; + assert_eq!(backend.get(&1), None); + + // late loader finish will NOT bring the entry back + barrier.wait().await; + notify_idle.notified_with_timeout().await; + assert_eq!(backend.get(&1), None); +} + +#[tokio::test] +async fn test_refresh_can_trigger_lru_eviction() { + maybe_start_logging(); + + let TestStateLRUAndRefresh { + mut backend, + refresh_duration_provider, + loader, + size_estimator, + time_provider, + notify_idle, + pool, + .. + } = TestStateLRUAndRefresh::new(); + + assert_eq!(pool.limit(), TestSize(10)); + + loader.mock_next(1, String::from("b")); + + refresh_duration_provider.set_refresh_in( + 1, + String::from("a"), + Some(backoff_cfg(Duration::from_secs(1))), + ); + refresh_duration_provider.set_refresh_in(1, String::from("b"), None); + refresh_duration_provider.set_refresh_in(2, String::from("c"), None); + refresh_duration_provider.set_refresh_in(3, String::from("d"), None); + + size_estimator.mock_size(1, String::from("a"), TestSize(1)); + size_estimator.mock_size(1, String::from("b"), TestSize(9)); + size_estimator.mock_size(2, String::from("c"), TestSize(1)); + size_estimator.mock_size(3, String::from("d"), TestSize(1)); + + backend.set(1, String::from("a")); + backend.set(2, String::from("c")); + backend.set(3, String::from("d")); + pool.wait_converged().await; + assert_eq!(backend.get(&2), Some(String::from("c"))); + assert_eq!(backend.get(&3), Some(String::from("d"))); + 
time_provider.inc(Duration::from_millis(1)); + assert_eq!(backend.get(&1), Some(String::from("a"))); + + // refresh + time_provider.inc(Duration::from_secs(10)); + notify_idle.notified_with_timeout().await; + pool.wait_converged().await; + + // needed to evict 2->"c" + assert_eq!(backend.get(&1), Some(String::from("b"))); + assert_eq!(backend.get(&2), None); + assert_eq!(backend.get(&3), Some(String::from("d"))); +} + +#[tokio::test] +async fn test_lru_learns_about_ttl_evictions() { + let TestStateTtlAndLRU { + mut backend, + ttl_provider, + size_estimator, + time_provider, + pool, + .. + } = TestStateTtlAndLRU::new().await; + + assert_eq!(pool.limit(), TestSize(10)); + + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1))); + ttl_provider.set_expires_in(2, String::from("b"), None); + ttl_provider.set_expires_in(3, String::from("c"), None); + + size_estimator.mock_size(1, String::from("a"), TestSize(4)); + size_estimator.mock_size(2, String::from("b"), TestSize(4)); + size_estimator.mock_size(3, String::from("c"), TestSize(4)); + + backend.set(1, String::from("a")); + backend.set(2, String::from("b")); + + assert_eq!(pool.current(), TestSize(8)); + + // evict + time_provider.inc(Duration::from_secs(1)); + assert_eq!(backend.get(&1), None); + + // now there's space for 3->"c" + assert_eq!(pool.current(), TestSize(4)); + backend.set(3, String::from("c")); + + assert_eq!(pool.current(), TestSize(8)); + assert_eq!(backend.get(&1), None); + assert_eq!(backend.get(&2), Some(String::from("b"))); + assert_eq!(backend.get(&3), Some(String::from("c"))); +} + +#[tokio::test] +async fn test_remove_if_check_does_not_extend_lifetime() { + let TestStateLruAndRemoveIf { + mut backend, + size_estimator, + time_provider, + remove_if_handle, + pool, + .. 
+ } = TestStateLruAndRemoveIf::new().await; + + size_estimator.mock_size(1, String::from("a"), TestSize(4)); + size_estimator.mock_size(2, String::from("b"), TestSize(4)); + size_estimator.mock_size(3, String::from("c"), TestSize(4)); + + backend.set(1, String::from("a")); + pool.wait_converged().await; + time_provider.inc(Duration::from_secs(1)); + + backend.set(2, String::from("b")); + pool.wait_converged().await; + time_provider.inc(Duration::from_secs(1)); + + // Checking remove_if should not count as a "use" of 1 + // for the "least recently used" calculation + remove_if_handle.remove_if(&1, |_| false); + backend.set(3, String::from("c")); + pool.wait_converged().await; + + // adding "c" totals 12 size, but backend has room for only 10 + // so "least recently used" (in this case 1, not 2) should be removed + assert_eq!(backend.get(&1), None); + assert!(backend.get(&2).is_some()); +} + +/// Test setup that integrates the TTL policy with a refresh. +struct TestStateTtlAndRefresh { + backend: PolicyBackend, + ttl_provider: Arc, + refresh_duration_provider: Arc, + time_provider: Arc, + loader: Arc>, + notify_idle: Arc, +} + +impl TestStateTtlAndRefresh { + fn new() -> Self { + let refresh_duration_provider = Arc::new(TestRefreshDurationProvider::new()); + let ttl_provider = Arc::new(TestTtlProvider::new()); + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let metric_registry = metric::Registry::new(); + let loader = Arc::new(TestLoader::default()); + let notify_idle = Arc::new(Notify::new()); + + // set up "RNG" that always generates the maximum, so we can test things easier + let rng_overwrite = StepRng::new(u64::MAX, 0); + + let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend.add_policy(RefreshPolicy::new_inner( + Arc::clone(&time_provider) as _, + Arc::clone(&refresh_duration_provider) as _, + Arc::clone(&loader) as _, + "my_cache", + &metric_registry, + Arc::clone(¬ify_idle), + &Handle::current(), + 
Some(rng_overwrite), + )); + backend.add_policy(TtlPolicy::new( + Arc::clone(&ttl_provider) as _, + "my_cache", + &metric_registry, + )); + + Self { + backend, + refresh_duration_provider, + ttl_provider, + time_provider, + loader, + notify_idle, + } + } +} + +/// Test setup that integrates the LRU policy with a refresh. +struct TestStateLRUAndRefresh { + backend: PolicyBackend, + size_estimator: Arc, + refresh_duration_provider: Arc, + time_provider: Arc, + loader: Arc>, + pool: Arc>, + notify_idle: Arc, +} + +impl TestStateLRUAndRefresh { + fn new() -> Self { + let refresh_duration_provider = Arc::new(TestRefreshDurationProvider::new()); + let size_estimator = Arc::new(TestSizeEstimator::default()); + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let metric_registry = Arc::new(metric::Registry::new()); + let loader = Arc::new(TestLoader::default()); + let notify_idle = Arc::new(Notify::new()); + + // set up "RNG" that always generates the maximum, so we can test things easier + let rng_overwrite = StepRng::new(u64::MAX, 0); + + let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend.add_policy(RefreshPolicy::new_inner( + Arc::clone(&time_provider) as _, + Arc::clone(&refresh_duration_provider) as _, + Arc::clone(&loader) as _, + "my_cache", + &metric_registry, + Arc::clone(¬ify_idle), + &Handle::current(), + Some(rng_overwrite), + )); + let pool = Arc::new(ResourcePool::new( + "my_pool", + TestSize(10), + Arc::clone(&metric_registry), + &Handle::current(), + )); + backend.add_policy(LruPolicy::new( + Arc::clone(&pool), + "my_cache", + Arc::clone(&size_estimator) as _, + )); + + Self { + backend, + refresh_duration_provider, + size_estimator, + time_provider, + loader, + pool, + notify_idle, + } + } +} + +/// Test setup that integrates the TTL policy with LRU. 
+struct TestStateTtlAndLRU { + backend: PolicyBackend, + ttl_provider: Arc, + time_provider: Arc, + size_estimator: Arc, + pool: Arc>, +} + +impl TestStateTtlAndLRU { + async fn new() -> Self { + let ttl_provider = Arc::new(TestTtlProvider::new()); + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let metric_registry = Arc::new(metric::Registry::new()); + let size_estimator = Arc::new(TestSizeEstimator::default()); + + let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend.add_policy(TtlPolicy::new( + Arc::clone(&ttl_provider) as _, + "my_cache", + &metric_registry, + )); + let pool = Arc::new(ResourcePool::new( + "my_pool", + TestSize(10), + Arc::clone(&metric_registry), + &Handle::current(), + )); + backend.add_policy(LruPolicy::new( + Arc::clone(&pool), + "my_cache", + Arc::clone(&size_estimator) as _, + )); + + Self { + backend, + ttl_provider, + time_provider, + size_estimator, + pool, + } + } +} + +/// Test setup that integrates the LRU policy with RemoveIf and max size of 10 +struct TestStateLruAndRemoveIf { + backend: PolicyBackend, + time_provider: Arc, + size_estimator: Arc, + remove_if_handle: RemoveIfHandle, + pool: Arc>, +} + +impl TestStateLruAndRemoveIf { + async fn new() -> Self { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let metric_registry = Arc::new(metric::Registry::new()); + let size_estimator = Arc::new(TestSizeEstimator::default()); + + let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + + let pool = Arc::new(ResourcePool::new( + "my_pool", + TestSize(10), + Arc::clone(&metric_registry), + &Handle::current(), + )); + backend.add_policy(LruPolicy::new( + Arc::clone(&pool), + "my_cache", + Arc::clone(&size_estimator) as _, + )); + + let (constructor, remove_if_handle) = + RemoveIfPolicy::create_constructor_and_handle("my_cache", &metric_registry); + backend.add_policy(constructor); + + Self { + backend, + time_provider, + size_estimator, + 
remove_if_handle, + pool, + } + } +} + +/// Test setup that integrates the LRU policy with a refresh. +struct TestStateLruAndRefresh { + backend: PolicyBackend, + size_estimator: Arc, + refresh_duration_provider: Arc, + time_provider: Arc, + loader: Arc>, + notify_idle: Arc, + pool: Arc>, +} + +impl TestStateLruAndRefresh { + fn new() -> Self { + let refresh_duration_provider = Arc::new(TestRefreshDurationProvider::new()); + let size_estimator = Arc::new(TestSizeEstimator::default()); + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let metric_registry = Arc::new(metric::Registry::new()); + let loader = Arc::new(TestLoader::default()); + let notify_idle = Arc::new(Notify::new()); + + // set up "RNG" that always generates the maximum, so we can test things easier + let rng_overwrite = StepRng::new(u64::MAX, 0); + + let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend.add_policy(RefreshPolicy::new_inner( + Arc::clone(&time_provider) as _, + Arc::clone(&refresh_duration_provider) as _, + Arc::clone(&loader) as _, + "my_cache", + &metric_registry, + Arc::clone(¬ify_idle), + &Handle::current(), + Some(rng_overwrite), + )); + + let pool = Arc::new(ResourcePool::new( + "my_pool", + TestSize(10), + Arc::clone(&metric_registry), + &Handle::current(), + )); + backend.add_policy(LruPolicy::new( + Arc::clone(&pool), + "my_cache", + Arc::clone(&size_estimator) as _, + )); + + Self { + backend, + refresh_duration_provider, + size_estimator, + time_provider, + loader, + notify_idle, + pool, + } + } +} + +#[derive(Debug, Default)] +struct TestSizeEstimator { + sizes: Mutex>, +} + +impl TestSizeEstimator { + fn mock_size(&self, k: u8, v: String, s: TestSize) { + self.sizes.lock().insert((k, v), s); + } +} + +impl ResourceEstimator for TestSizeEstimator { + type K = u8; + type V = String; + type S = TestSize; + + fn consumption(&self, k: &Self::K, v: &Self::V) -> Self::S { + *self.sizes.lock().get(&(*k, v.clone())).unwrap() + } 
+} diff --git a/cache_system/src/backend/policy/lru.rs b/cache_system/src/backend/policy/lru.rs new file mode 100644 index 0000000..4f5c9ab --- /dev/null +++ b/cache_system/src/backend/policy/lru.rs @@ -0,0 +1,2055 @@ +//! LRU (Least Recently Used) cache system. +//! +//! # Usage +//! +//! ``` +//! # tokio::runtime::Runtime::new().unwrap().block_on(async { +//! use std::{ +//! collections::HashMap, +//! ops::{Add, Sub}, +//! sync::Arc, +//! }; +//! use iox_time::SystemProvider; +//! use cache_system::{ +//! backend::{ +//! CacheBackend, +//! policy::{ +//! lru::{LruPolicy, ResourcePool}, +//! PolicyBackend, +//! }, +//! }, +//! resource_consumption::{Resource, ResourceEstimator}, +//! }; +//! use tokio::runtime::Handle; +//! +//! // first we implement a strongly-typed RAM size measurement +//! #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +//! struct RamSize(usize); +//! +//! impl Resource for RamSize { +//! fn zero() -> Self { +//! Self(0) +//! } +//! +//! fn unit() -> &'static str { +//! "bytes" +//! } +//! } +//! +//! impl From for u64 { +//! fn from(s: RamSize) -> Self { +//! s.0 as Self +//! } +//! } +//! +//! impl Add for RamSize { +//! type Output = Self; +//! +//! fn add(self, rhs: Self) -> Self::Output { +//! Self(self.0.checked_add(rhs.0).expect("overflow")) +//! } +//! } +//! +//! impl Sub for RamSize { +//! type Output = Self; +//! +//! fn sub(self, rhs: Self) -> Self::Output { +//! Self(self.0.checked_sub(rhs.0).expect("underflow")) +//! } +//! } +//! +//! // a time provider is required to determine the age of entries +//! let time_provider = Arc::new(SystemProvider::new()); +//! +//! // registry to capture metrics emitted by the LRU cache +//! let metric_registry = Arc::new(metric::Registry::new()); +//! +//! // set up a memory pool +//! let limit = RamSize(50); +//! let pool = Arc::new(ResourcePool::new( +//! "my_pool", +//! limit, +//! metric_registry, +//! &Handle::current(), +//! )); +//! +//! 
// set up first pool user: a u64->String map +//! #[derive(Debug)] +//! struct Estimator1 {} +//! +//! impl ResourceEstimator for Estimator1 { +//! type K = u64; +//! type V = String; +//! type S = RamSize; +//! +//! fn consumption(&self, _k: &Self::K, v: &Self::V) -> Self::S { +//! RamSize(8) + RamSize(v.capacity()) +//! } +//! } +//! +//! let mut backend1 = PolicyBackend::new( +//! Box::new(HashMap::new()), +//! Arc::clone(&time_provider) as _, +//! ); +//! backend1.add_policy( +//! LruPolicy::new( +//! Arc::clone(&pool), +//! "id1", +//! Arc::new(Estimator1{}), +//! ) +//! ); +//! +//! // add some data +//! backend1.set(1, String::from("some_entry")); +//! backend1.set(2, String::from("another_entry")); +//! assert_eq!(pool.current(), RamSize(39)); +//! +//! // only test first one +//! assert!(backend1.get(&1).is_some()); +//! +//! // fill up pool +//! backend1.set(3, String::from("this_will_evict_data")); +//! +//! // the policy will eventually evict the data, in tests we can use a help +//! // method to wait for that +//! pool.wait_converged().await; +//! +//! assert!(backend1.get(&1).is_some()); +//! assert!(backend1.get(&2).is_none()); +//! assert!(backend1.get(&3).is_some()); +//! assert_eq!(pool.current(), RamSize(46)); +//! +//! // set up second pool user with totally different types: a u8->Vec map +//! #[derive(Debug)] +//! struct Estimator2 {} +//! +//! impl ResourceEstimator for Estimator2 { +//! type K = u8; +//! type V = Vec; +//! type S = RamSize; +//! +//! fn consumption(&self, _k: &Self::K, v: &Self::V) -> Self::S { +//! RamSize(1) + RamSize(v.capacity()) +//! } +//! } +//! +//! let mut backend2 = PolicyBackend::new( +//! Box::new(HashMap::new()), +//! time_provider, +//! ); +//! backend2.add_policy( +//! LruPolicy::new( +//! Arc::clone(&pool), +//! "id2", +//! Arc::new(Estimator2{}), +//! ) +//! ); +//! +//! // eviction works for all pool members +//! backend2.set(1, vec![1, 2, 3, 4]); +//! pool.wait_converged().await; +//! 
assert!(backend1.get(&1).is_none()); +//! assert!(backend1.get(&2).is_none()); +//! assert!(backend1.get(&3).is_some()); +//! assert!(backend2.get(&1).is_some()); +//! assert_eq!(pool.current(), RamSize(33)); +//! # }); +//! ``` +//! +//! # Internals +//! Here we describe the internals of the LRU cache system. +//! +//! ## Requirements +//! To understand the construction, we first must understand what the LRU system tries to achieve: +//! +//! - **Single Pool:** Have a single resource pool for multiple LRU backends. +//! - **Eviction Cascade:** Adding data to any of the backends (or modifying an existing entry) should check if there is +//! enough space left in the LRU backend. If not, we must EVENTUALLY remove the least recently used entries over all +//! backends (including the one that just got a new entry) until there is enough space. +//! +//! This has the following consequences: +//! +//! - **Cyclic Structure:** The LRU backends communicate with the pool, but the pool also needs to communicate with +//! all the backends. This creates some form of cyclic data structure. +//! - **Type Erasure:** The pool is only specific to the resource type, not the key and value types of the +//! participating backends. So at some place we need to perform type erasure. +//! +//! ## Data Structures +//! +//! ```text +//! .~~~~~~~~~~~~~~~~. +//! +---------------------------------------: CallbackHandle : +//! | : : +//! | .~~~~~~~~~~~~~~~~. +//! | ^ +//! | .~~~~~~~~~~~~~~~~~. | +//! | : AddressableHeap : | +//! | : : (mutex) +//! | .~~~~~~~~~~~~~~~~~. | +//! | ^ | +//! | | | +//! V (mutex) | +//! .~~~~~~~~~~~~~~~. .~~~~~~~~~~~. | .~~~~~~~~~~~~~~~~. .~~~~~~~~~~~~. +//! -->: PolicyBackend :--->: LruPolicy : | : PoolMemberImpl : : PoolMember : +//! : : : : | : : : : +//! : : : : +------: :<--(dyn)---: : +//! .~~~~~~~~~~~~~~~. .~~~~~~~~~~~. .~~~~~~~~~~~~~~~~. .~~~~~~~~~~~~. +//! | | ^ ^ +//! | | | | +//! | +--------------------------------------(arc)-----+ | +//! (arc) | +//! 
| (weak) +//! V | +//! .~~~~~~~~~~~~~~. .~~~~~~~~~~~~~. +//! ---------------------->: ResourcePool :-----+-------(arc)--------------------->: SharedState : +//! : : | : : +//! .~~~~~~~~~~~~~~. | .~~~~~~~~~~~~~. +//! | | +//! (handle) | +//! | | +//! V | +//! .~~~~~~~~~~~~~~~. | +//! : clean_up_loop :----+ +//! : : +//! .~~~~~~~~~~~~~~~. +//! ``` +//! +//! ## State +//! State is held in the following structures: +//! +//! - `LruPolicyInner`: Holds [`CallbackHandle`] as well as an [`AddressableHeap`] to +//! memorize when entries were used for the last time. +//! - `ResourcePoolInner`: Holds a reference to all pool members as well as the current consumption. +//! +//! All other structures and traits "only" act as glue. +//! +//! ## Locking +//! What and how we lock depends on the operation. +//! +//! Note that all locks are bare mutexes, there are no read-write-locks. "Only read" is not really an important use +//! case since even `get` requires updating the "last used" timestamp of the corresponding entry. +//! +//! ### Get +//! For [`GET`] we only need to update the "last used" timestamp for the affected entry. No +//! pool-wide operations are required. We update [`AddressableHeap`] and then perform the read operation of the inner +//! backend. +//! +//! ### Remove +//! For [`REMOVE`] the pool usage can only decrease, so other backends are never affected. We +//! first lock [`AddressableHeap`] and check if the entry is present. If it is, we also the "current" counter in +//! [`SharedState`] and then perform the modification on both. +//! +//! ### Set +//! [`SET`] locks [`AddressableHeap`] to figure out if th item exists. If it does, it locks the "current" counter in +//! [`SharedState`] and removes the old value. Then it updates [`AddressableHeap`] with the new value and locks&updates +//! the "current" counter in [`SharedState`] again. It then notifies the clean-up loop that there was an up. +//! +//! 
Note that in case of an override, the existing "last used" time will be used instead of "now", because just +//! replacing an existing value (e.g. via a [refresh]) should not count as a use. +//! +//! ### Clean-up Loop +//! This is the beefy bit. First it locks and reads the "current" counter in [`SharedState`]. It instantly unlocks the +//! value to not block all pool members adding new values while it we figure out what to evict. Then it selects victims +//! one by one by asking the individual pool members what they could remove. This shortly locks their +//! [`AddressableHeap`]s (one member at the time). After enough victims where selected for eviction, it will delete in +//! them one pool member at the time. Each pool member will lock their [`CallbackHandle`] and when the deletion happens +//! also their [`AddressableHeap`] and the "current" counter in [`SharedState`]. However the lock order is identical to +//! a normal "remove" operation. +//! +//! Note that the clean up loop does not directly update the "current" counter in [`SharedState`] since the "remove" +//! routine already does that. +//! +//! ## Consistency +//! This system is eventually consistent and we are a bit loose at a few places to make it more efficient and easier to +//! implement. This subsection explains cases where this could be visible to an observer. +//! +//! ### Overcommit +//! Since we add new data to the cache pool and the clean-up loop will eventually evict data, we overcommit the pool for +//! a short time. In practice however we already allocated the memory before adding it to the pool. +//! +//! There is a another risk that the cached users will add data so fast that the clean-up loop cannot keep up. This +//! however is highly unlikely, since the loop selects enough victims to get the resource usage below the limit and +//! deletes these victims in batches. The more it runs behind, the large the batch will be. +//! +//! ### Overdelete +//! 
Similar to "overcommit", it is possible that the clean-up loop deletes more items than necessary. This can happen +//! when between victim selection and actual deletion, entries are removed from the cache (e.g. via [TTL]). However the +//! timing for that is very tight and we would have deleted the data anyways if the delete would have happened a tiny +//! bit later, so in reality this is not a concern. On the other hand, the effect might also be a cache miss that was +//! not strictly necessary and in turn worse performance than we could have had. +//! +//! ### Victim-Use-Delete +//! It is possible that a key is used between victim selection and its removal. In theory we should not remove the key +//! in this case because its no longer "least recently used". However if the key usage would have occurred only a bit +//! later, we would have removed the key anyways so this tight race has no practical meaning. No user can rely on such +//! tight timings and the fullness of a cache pool. +//! +//! ### Victim-Downsize-Delete +//! A selected victim might be replaced with a smaller one between victim selection and its deletion. In this case, the +//! clean-up loop does not delete enough data in its current try but needs an additional iteration. In reality this is +//! very unlikely since most cached entries rarely shrink and even if they do, the clean-up loop will eventually catch +//! up again. +//! +//! +//! [`GET`]: Subscriber::get +//! [`PolicyBackend`]: super::PolicyBackend +//! [refresh]: super::refresh +//! [`REMOVE`]: Subscriber::remove +//! [`SET`]: Subscriber::set +//! 
[TTL]: super::ttl +use std::{ + any::Any, + collections::{btree_map::Entry, BTreeMap, BinaryHeap}, + fmt::Debug, + hash::Hash, + sync::{Arc, Weak}, +}; + +use iox_time::Time; +use metric::{U64Counter, U64Gauge}; +use observability_deps::tracing::trace; +use ouroboros::self_referencing; +use parking_lot::Mutex; +use tokio::{runtime::Handle, sync::Notify, task::JoinSet}; + +use crate::{ + addressable_heap::{AddressableHeap, AddressableHeapIter}, + backend::CacheBackend, + resource_consumption::{Resource, ResourceEstimator}, +}; + +use super::{CallbackHandle, ChangeRequest, Subscriber}; + +/// Wrapper around something that can be converted into `u64` +/// to enable emitting metrics. +#[derive(Debug)] +struct MeasuredT +where + S: Resource, +{ + v: S, + metric: U64Gauge, +} + +impl MeasuredT +where + S: Resource, +{ + fn new(v: S, metric: U64Gauge) -> Self { + metric.set(v.into()); + + Self { v, metric } + } + + fn inc(&mut self, delta: &S) { + self.v = self.v + *delta; + self.metric.inc((*delta).into()); + } + + fn dec(&mut self, delta: &S) { + self.v = self.v - *delta; + self.metric.dec((*delta).into()); + } +} + +/// Shared state between [`ResourcePool`] and [`clean_up_loop`]. +#[derive(Debug)] +struct SharedState +where + S: Resource, +{ + /// Resource limit. + limit: MeasuredT, + + /// Current resource usage. + current: Mutex>, + + /// Members (= backends) that use this pool. + members: Mutex>>>, + + /// Notification when [`current`](Self::current) as changed. + change_notify: Notify, +} + +impl SharedState +where + S: Resource, +{ + /// Get current members. + /// + /// This also performs a clean-up. + fn members(&self) -> BTreeMap<&'static str, Arc>> { + let mut members = self.members.lock(); + let mut out = BTreeMap::new(); + + members.retain(|id, member| match member.upgrade() { + Some(member) => { + out.insert(*id, member); + true + } + None => false, + }); + + out + } +} + +/// Resource pool. +/// +/// This can be used with [`LruPolicy`]. 
+#[derive(Debug)] +pub struct ResourcePool +where + S: Resource, +{ + /// Name of the pool. + name: &'static str, + + /// Shared state. + shared: Arc>, + + /// Metric registry associated with the pool. + /// + /// This is used to generate member-specific metrics as well. + metric_registry: Arc, + + /// Background task. + _background_task: JoinSet<()>, + + /// Notification when the background worker is idle, so tests know that the state has converged and that they can + /// continue working. + #[allow(dead_code)] + notify_idle_test_side: + tokio::sync::mpsc::UnboundedSender>, +} + +impl ResourcePool +where + S: Resource, +{ + /// Creates new empty resource pool with given limit. + pub fn new( + name: &'static str, + limit: S, + metric_registry: Arc, + runtime_handle: &Handle, + ) -> Self { + let metric_limit = metric_registry + .register_metric::("cache_lru_pool_limit", "Limit of the LRU resource pool") + .recorder(&[("unit", S::unit()), ("pool", name)]); + let limit = MeasuredT::new(limit, metric_limit); + + let metric_current = metric_registry + .register_metric::( + "cache_lru_pool_usage", + "Current consumption of the LRU resource pool", + ) + .recorder(&[("unit", S::unit()), ("pool", name)]); + let current = Mutex::new(MeasuredT::new(S::zero(), metric_current)); + + let shared = Arc::new(SharedState { + limit, + current, + members: Default::default(), + change_notify: Default::default(), + }); + + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + + let mut background_task = JoinSet::new(); + background_task.spawn_on(clean_up_loop(Arc::clone(&shared), rx), runtime_handle); + + Self { + name, + shared, + metric_registry, + _background_task: background_task, + notify_idle_test_side: tx, + } + } + + /// Get pool limit. + pub fn limit(&self) -> S { + self.shared.limit.v + } + + /// Get current pool usage. + pub fn current(&self) -> S { + self.shared.current.lock().v + } + + /// Register new pool member. 
+ /// + /// # Panic + /// Panics when a member with the specific ID is already registered. + fn register_member(&self, id: &'static str, member: Weak>) { + let mut members = self.shared.members.lock(); + + match members.entry(id) { + Entry::Vacant(v) => { + v.insert(member); + } + Entry::Occupied(mut o) => { + if o.get().strong_count() > 0 { + panic!("Member '{}' already registered", o.key()); + } else { + *o.get_mut() = member; + } + } + } + } + + /// Add used resource from pool. + fn add(&self, s: S) { + let mut current = self.shared.current.lock(); + current.inc(&s); + if current.v > self.shared.limit.v { + self.shared.change_notify.notify_one(); + } + } + + /// Remove used resource from pool. + fn remove(&self, s: S) { + self.shared.current.lock().dec(&s); + } + + /// Wait for the pool to converge to a steady state. + /// + /// This usually means that the background worker that runs the eviction loop is idle. + /// + /// # Panic + /// Panics if the background worker is not idle within 5s or if the worker died. + pub async fn wait_converged(&self) { + let (tx, rx) = futures::channel::oneshot::channel(); + self.notify_idle_test_side + .send(tx) + .expect("background worker alive"); + tokio::time::timeout(std::time::Duration::from_secs(5), rx) + .await + .unwrap() + .unwrap(); + } +} + +/// Cache policy that wraps another backend and limits its resource usage. +#[derive(Debug)] +pub struct LruPolicy +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, + S: Resource, +{ + /// Link to central resource pool. + pool: Arc>, + + /// Pool member + member: Arc>, + + /// Resource estimator that is used for new (via [`SET`](Subscriber::set)) entries. + resource_estimator: Arc>, + + /// Count number of elements within this specific pool member. + metric_count: U64Gauge, + + /// Count resource usage of this specific pool member. 
+ metric_usage: U64Gauge, +} + +impl LruPolicy +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, + S: Resource, +{ + /// Create new backend w/o any known keys. + /// + /// The inner backend MUST NOT contain any data at this point, otherwise we will not track any resource consumption + /// for these entries. + /// + /// # Panic + /// - Panics if the given ID is already used within the given pool. + /// - If the inner backend is not empty. + pub fn new( + pool: Arc>, + id: &'static str, + resource_estimator: Arc>, + ) -> impl FnOnce(CallbackHandle) -> Self { + let metric_count = pool + .metric_registry + .register_metric::( + "cache_lru_member_count", + "Number of entries for a given LRU cache pool member", + ) + .recorder(&[("pool", pool.name), ("member", id)]); + let metric_usage = pool + .metric_registry + .register_metric::( + "cache_lru_member_usage", + "Resource usage of a given LRU cache pool member", + ) + .recorder(&[("pool", pool.name), ("member", id), ("unit", S::unit())]); + let metric_evicted = pool + .metric_registry + .register_metric::( + "cache_lru_member_evicted", + "Number of entries that were evicted from a given LRU cache pool member", + ) + .recorder(&[("pool", pool.name), ("member", id)]); + + move |mut callback_handle| { + callback_handle.execute_requests(vec![ChangeRequest::ensure_empty()]); + + let member = Arc::new(PoolMemberImpl { + id, + last_used: Arc::new(Mutex::new(AddressableHeap::new())), + metric_evicted, + callback_handle: Mutex::new(callback_handle), + }); + + pool.register_member(id, Arc::downgrade(&member) as _); + + Self { + pool, + member, + resource_estimator, + metric_count, + metric_usage, + } + } + } +} + +impl Drop for LruPolicy +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, + S: Resource, +{ + fn drop(&mut self) { + let size_total = { + let mut guard = self.member.last_used.lock(); + let mut accu = S::zero(); + while let 
Some((_k, s, _t)) = guard.pop() { + accu = accu + s; + } + accu + }; + self.pool.remove(size_total); + } +} + +impl Subscriber for LruPolicy +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, + S: Resource, +{ + type K = K; + type V = V; + + fn get(&mut self, k: &Self::K, now: Time) -> Vec> { + trace!(?k, now = now.timestamp_nanos(), "LRU get",); + let mut last_used = self.member.last_used.lock(); + + // update "last used" + last_used.update_order(k, now); + + vec![] + } + + fn set( + &mut self, + k: &Self::K, + v: &Self::V, + now: Time, + ) -> Vec> { + trace!(?k, now = now.timestamp_nanos(), "LRU set",); + + // determine all attributes before getting any locks + let consumption = self.resource_estimator.consumption(k, v); + + // "last used" time for new entry + // Note: this might be updated if the entry already exists + let mut last_used_t = now; + + // check for oversized entries + if consumption > self.pool.shared.limit.v { + return vec![ChangeRequest::remove(k.clone())]; + } + + { + let mut last_used = self.member.last_used.lock(); + + // maybe clean from pool + if let Some((consumption, last_used_t_previously)) = last_used.remove(k) { + self.pool.remove(consumption); + self.metric_count.dec(1); + self.metric_usage.dec(consumption.into()); + last_used_t = last_used_t_previously; + } + + // add new entry to inner backend BEFORE adding it to the pool, because the we can overcommit for a short + // time and we want to give the pool a chance to also evict the new resource + last_used.insert(k.clone(), consumption, last_used_t); + self.metric_count.inc(1); + self.metric_usage.inc(consumption.into()); + } + + // pool-wide operation + // Since this may wake-up the background worker and cause evictions, drop the `last_used` lock before doing this (see + // block above) to avoid lock contention. 
+ self.pool.add(consumption); + + vec![] + } + + fn remove(&mut self, k: &Self::K, now: Time) -> Vec> { + trace!(?k, now = now.timestamp_nanos(), "LRU remove",); + let mut last_used = self.member.last_used.lock(); + + if let Some((consumption, _last_used)) = last_used.remove(k) { + self.pool.remove(consumption); + self.metric_count.dec(1); + self.metric_usage.dec(consumption.into()); + } + + vec![] + } +} + +/// Iterator for enumerating removal candidates of a [`PoolMember`]. +/// +/// This is type-erased to make [`PoolMember`] object-safe. +type PoolMemberCouldRemove = Box)>>; + +/// A member of a [`ResourcePool`]/[`SharedState`]. +/// +/// The only implementation of this is [`PoolMemberImpl`]. This indirection is required to erase `K` and `V` from specific +/// backend so we can stick it into the generic pool. +trait PoolMember: Debug + Send + Sync + 'static { + /// Resource type. + type S; + + /// Check if this member has anything that could be removed. + /// + /// If so, return: + /// - "last used" timestamp + /// - resource consumption of that entry + /// - type-erased key + /// + /// Elements are returned in order of the "last used" timestamp, in increasing order. + fn could_remove(&self) -> PoolMemberCouldRemove; + + /// Remove given set of keys. + /// + /// The keys MUST be a result of [`could_remove`](Self::could_remove), otherwise the downcasting may not work and panic. + fn remove_keys(&self, keys: Vec>); +} + +/// The only implementation of [`PoolMember`]. +/// +/// In contrast to the trait, this still contains `K` and `V`. +#[derive(Debug)] +pub struct PoolMemberImpl +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, + S: Resource, +{ + /// Pool member ID. + id: &'static str, + + /// Count number of evicted items. + metric_evicted: U64Counter, + + /// Tracks usage of the last used elements. + /// + /// See documentation of [`callback_handle`](Self::callback_handle) for a reasoning about locking. 
+ last_used: Arc>>, + + /// Handle to call back into the [`PolicyBackend`] to evict data. + /// + /// # Locking + /// This MUST NOT share a lock with [`last_used`](Self::last_used) because otherwise we would deadlock during + /// eviction: + /// + /// 1. [`remove_keys`](PoolMember::remove_keys) + /// 2. lock both [`callback_handle`](Self::callback_handle) and [`last_used`](Self::last_used) + /// 3. [`CallbackHandle::execute_requests`] + /// 4. [`Subscriber::remove`] + /// 5. need to lock [`last_used`](Self::last_used) again + /// + /// + /// [`PolicyBackend`]: super::PolicyBackend + callback_handle: Mutex>, +} + +impl PoolMember for PoolMemberImpl +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, + S: Resource, +{ + type S = S; + + fn could_remove(&self) -> Box)>> { + it::build_it(self.last_used.lock_arc()) + } + + fn remove_keys(&self, keys: Vec>) { + let keys = keys + .into_iter() + .map(|k| *k.downcast::().expect("wrong type")) + .collect::>(); + + trace!( + id = self.id, + ?keys, + "evicting cache entries due to LRU pressure", + ); + self.metric_evicted.inc(keys.len() as u64); + + let combined = ChangeRequest::from_fn(move |backend| { + for k in keys { + backend.remove(&k); + } + }); + + self.callback_handle.lock().execute_requests(vec![combined]); + } +} + +/// Helper module that wraps the iterator handling for [`PoolMember`]/[`PoolMemberImpl`]. +/// +/// This is required because [`ouroboros`] generates a bunch of code that we do not want to leak all over the place. +mod it { + // ignore some lints for the ouroboros codegen + #![allow(clippy::future_not_send)] + + use super::*; + + /// The lock that we need to generate a candidate iterator. 
+ pub type Lock = + parking_lot::lock_api::ArcMutexGuard>; + + #[self_referencing] + struct PoolMemberIter + where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + S: Resource, + { + lock: Lock, + + #[borrows(lock)] + #[covariant] + it: AddressableHeapIter<'this, K, S, Time>, + } + + impl Iterator for PoolMemberIter + where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + S: Resource, + { + type Item = (Time, S, Box); + + fn next(&mut self) -> Option { + self.with_it_mut(|it| { + it.next() + .map(|(k, s, t)| (*t, *s, Box::new(k.clone()) as _)) + }) + } + + fn size_hint(&self) -> (usize, Option) { + self.borrow_it().size_hint() + } + } + + /// Build iterator. + pub fn build_it(lock: Lock) -> PoolMemberCouldRemove + where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + S: Resource, + { + Box::new( + PoolMemberIterBuilder { + lock, + it_builder: |lock| lock.iter(), + } + .build(), + ) + } +} + +/// Background worker that eventually cleans up data if the pool reaches capacity. +/// +/// This method NEVER returns. +async fn clean_up_loop( + shared: Arc>, + mut notify_idle_worker_side: tokio::sync::mpsc::UnboundedReceiver< + futures::channel::oneshot::Sender<()>, + >, +) where + S: Resource, +{ + 'outer: loop { + // yield to tokio so that the runtime has a chance to abort this function during shutdown + tokio::task::yield_now().await; + + // get current value but drop the lock immediately + // Especially we must NOT hold the lock when we later execute the change requests, otherwise there will be two + // lock direction: + // - someone adding new resource: member -> pool + // - clean up loop: pool -> memeber + let mut current = { + let guard = shared.current.lock(); + guard.v + }; + + if current <= shared.limit.v { + // nothing to do, sleep and then continue w/ next round + loop { + tokio::select! 
{ + // biased sleep so we can notify test hooks if we're idle + biased; + + _ = shared.change_notify.notified() => {continue 'outer;}, + + idle_notify = notify_idle_worker_side.recv() => { + if let Some(n) = idle_notify { + n.send(()).ok(); + } + }, + } + } + } + + // receive members + // Do NOT hold the member lock during the deletion later because this can lead to deadlocks during shutdown. + let members = shared.members(); + if members.is_empty() { + // early retry, there's nothing we can do + continue; + } + + // select victims + let mut victims: BTreeMap<&'static str, Vec>> = Default::default(); + { + trace!( + current = current.into(), + limit = shared.limit.v.into(), + "select eviction victims" + ); + + // limit scope of member iterators, because they contain locks and we MUST drop them before proceeding to + // the actual deletion + let mut heap: BinaryHeap> = members + .iter() + .map(|(id, member)| EvictionCandidateIter::new(id, member.could_remove())) + .collect(); + + while current > shared.limit.v { + let candidate = heap.pop().expect("checked that we have at least 1 member"); + let (candidate, victim) = candidate.next(); + + match victim { + Some((t, s, k)) => { + trace!( + id = candidate.id, + s = s.into(), + t_ns = t.timestamp_nanos(), + "found victim" + ); + current = current - s; + victims.entry(candidate.id).or_default().push(k); + } + None => { + // The custom `Ord` implementation ensures that we prefer iterators with data over iterators + // without any candidates. So if the "best" iterators has NO candidates, this means that ALL + // iterators are empty. + // + // Or in other words: some data was deleted between retrieving the "current" value and locking + // the iterators. This is fine, just stop looping and remove the victims that we have selected + // so far. 
+ trace!("no more data"); + break; + } + } + + heap.push(candidate); + } + + trace!("done selecting eviction victims"); + } + + for (id, keys) in victims { + let member = members.get(id).expect("did get this ID from this map"); + member.remove_keys(keys); + } + } +} + +/// Current element presented by the [`EvictionCandidateIter`]. +type EvictionCandidate = Option<(Time, S, Box)>; + +/// Wraps a [`PoolMember`] so we can compare it in a "tournament" to find out what data to evict. +struct EvictionCandidateIter +where + S: Resource, +{ + id: &'static str, + it: PoolMemberCouldRemove, + current: EvictionCandidate, +} + +impl EvictionCandidateIter +where + S: Resource, +{ + fn new(id: &'static str, mut it: PoolMemberCouldRemove) -> Self { + let current = it.next(); + Self { id, it, current } + } + + /// Get next eviction candidate. + /// + /// This advances the internal state so that this iterator compares correctly afterwards. + fn next(mut self) -> (Self, EvictionCandidate) { + let mut tmp = self.it.next(); + std::mem::swap(&mut tmp, &mut self.current); + (self, tmp) + } +} + +impl PartialEq for EvictionCandidateIter +where + S: Resource, +{ + fn eq(&self, other: &Self) -> bool { + match (self.current.as_ref(), other.current.as_ref()) { + (None, None) | (Some(_), None) | (None, Some(_)) => false, + (Some((t1, s1, _k1)), Some((t2, s2, _k2))) => (t1, s1) == (t2, s2), + } + } +} + +impl Eq for EvictionCandidateIter where S: Resource {} + +impl PartialOrd for EvictionCandidateIter +where + S: Resource, +{ + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for EvictionCandidateIter +where + S: Resource, +{ + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + // Note: reverse order because iterators are kept in a MAX heap + match (self.current.as_ref(), other.current.as_ref()) { + (None, None) => { + // break tie + self.id.cmp(other.id).reverse() + } + + // prefer iterators with candidates over empty iterators + (Some(_), None) 
=> std::cmp::Ordering::Greater, + (None, Some(_)) => std::cmp::Ordering::Less, + + (Some((t1, _s1, _k1)), Some((t2, _s2, _k2))) => { + // compare by time, break tie using member ID + (t1, self.id).cmp(&(t2, other.id)).reverse() + } + } + } +} + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, time::Duration}; + + use iox_time::{MockProvider, SystemProvider}; + use metric::{Observation, RawReporter}; + use test_helpers::maybe_start_logging; + + use crate::{ + backend::{policy::PolicyBackend, CacheBackend}, + resource_consumption::test_util::TestSize, + }; + + use super::*; + + #[tokio::test] + #[should_panic(expected = "inner backend is not empty")] + async fn test_panic_inner_not_empty() { + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_nanos(0))); + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(10), + Arc::new(metric::Registry::new()), + &Handle::current(), + )); + let resource_estimator = Arc::new(TestResourceEstimator {}); + + let mut backend = PolicyBackend::hashmap_backed(time_provider); + let policy_constructor = LruPolicy::new( + Arc::clone(&pool), + "id", + Arc::clone(&resource_estimator) as _, + ); + backend.add_policy(|mut callback_handle| { + callback_handle.execute_requests(vec![ChangeRequest::set(String::from("foo"), 1usize)]); + policy_constructor(callback_handle) + }) + } + + #[tokio::test] + #[should_panic(expected = "Member 'id' already registered")] + async fn test_panic_id_collision() { + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_nanos(0))); + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(10), + Arc::new(metric::Registry::new()), + &Handle::current(), + )); + let resource_estimator = Arc::new(TestResourceEstimator {}); + + let mut backend1 = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend1.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id", + Arc::clone(&resource_estimator) as _, + )); + + let mut backend2 = 
PolicyBackend::hashmap_backed(time_provider); + backend2.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id", + Arc::clone(&resource_estimator) as _, + )); + } + + #[tokio::test] + async fn test_reregister_member() { + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_nanos(0))); + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(10), + Arc::new(metric::Registry::new()), + &Handle::current(), + )); + let resource_estimator = Arc::new(TestResourceEstimator {}); + + let mut backend1 = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend1.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id", + Arc::clone(&resource_estimator) as _, + )); + backend1.set(String::from("a"), 1usize); + assert_eq!(pool.current(), TestSize(1)); + + // drop the backend so re-registering the same ID ("id") MUST NOT panic + drop(backend1); + assert_eq!(pool.current(), TestSize(0)); + + let mut backend2 = PolicyBackend::hashmap_backed(time_provider); + backend2.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id", + Arc::clone(&resource_estimator) as _, + )); + backend2.set(String::from("a"), 2usize); + assert_eq!(pool.current(), TestSize(2)); + } + + #[tokio::test] + async fn test_empty() { + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_nanos(0))); + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(10), + Arc::new(metric::Registry::new()), + &Handle::current(), + )); + let resource_estimator = Arc::new(TestResourceEstimator {}); + + assert_eq!(pool.current().0, 0); + + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id1", + Arc::clone(&resource_estimator) as _, + )); + + assert_eq!(pool.current().0, 0); + } + + #[tokio::test] + async fn test_double_set() { + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_nanos(0))); + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(2), + 
Arc::new(metric::Registry::new()), + &Handle::current(), + )); + let resource_estimator = Arc::new(TestResourceEstimator {}); + + let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id1", + Arc::clone(&resource_estimator) as _, + )); + + backend.set(String::from("a"), 1usize); + time_provider.inc(Duration::from_millis(1)); + + backend.set(String::from("b"), 1usize); + time_provider.inc(Duration::from_millis(1)); + + // does NOT count as "used" + backend.set(String::from("a"), 1usize); + time_provider.inc(Duration::from_millis(1)); + + backend.set(String::from("c"), 1usize); + pool.wait_converged().await; + + assert_eq!(backend.get(&String::from("a")), None); + } + + #[tokio::test] + async fn test_override() { + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_nanos(0))); + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(10), + Arc::new(metric::Registry::new()), + &Handle::current(), + )); + let resource_estimator = Arc::new(TestResourceEstimator {}); + + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id1", + Arc::clone(&resource_estimator) as _, + )); + + backend.set(String::from("a"), 5usize); + assert_eq!(pool.current().0, 5); + + backend.set(String::from("b"), 3usize); + assert_eq!(pool.current().0, 8); + + backend.set(String::from("a"), 4usize); + assert_eq!(pool.current().0, 7); + } + + #[tokio::test] + async fn test_remove() { + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_nanos(0))); + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(10), + Arc::new(metric::Registry::new()), + &Handle::current(), + )); + let resource_estimator = Arc::new(TestResourceEstimator {}); + + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id1", + 
Arc::clone(&resource_estimator) as _, + )); + + backend.set(String::from("a"), 5usize); + assert_eq!(pool.current().0, 5); + + backend.set(String::from("b"), 3usize); + assert_eq!(pool.current().0, 8); + + backend.remove(&String::from("a")); + assert_eq!(pool.current().0, 3); + + assert_eq!(backend.get(&String::from("a")), None); + assert_inner_backend(&mut backend, [(String::from("b"), 3)]); + + // removing it again should just work + backend.remove(&String::from("a")); + assert_eq!(pool.current().0, 3); + } + + #[tokio::test] + async fn test_eviction_order() { + maybe_start_logging(); + + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_nanos(0))); + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(21), + Arc::new(metric::Registry::new()), + &Handle::current(), + )); + let resource_estimator = Arc::new(TestResourceEstimator {}); + + let mut backend1 = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend1.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id1", + Arc::clone(&resource_estimator) as _, + )); + + let mut backend2 = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend2.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id2", + Arc::clone(&resource_estimator) as _, + )); + + backend1.set(String::from("b"), 1usize); + backend2.set(String::from("a"), 2usize); + backend1.set(String::from("a"), 3usize); + backend1.set(String::from("c"), 4usize); + assert_eq!(pool.current().0, 10); + + time_provider.inc(Duration::from_millis(1)); + + backend1.set(String::from("d"), 5usize); + assert_eq!(pool.current().0, 15); + + time_provider.inc(Duration::from_millis(1)); + backend2.set(String::from("b"), 6usize); + assert_eq!(pool.current().0, 21); + + time_provider.inc(Duration::from_millis(1)); + + // now are exactly at capacity + pool.wait_converged().await; + assert_inner_backend( + &mut backend1, + [ + (String::from("a"), 3), + (String::from("b"), 1), + (String::from("c"), 4), + 
(String::from("d"), 5), + ], + ); + assert_inner_backend( + &mut backend2, + [(String::from("a"), 2), (String::from("b"), 6)], + ); + + // adding a single element will drop the smallest key from the first backend (by ID) + backend1.set(String::from("foo1"), 1usize); + pool.wait_converged().await; + assert_eq!(pool.current().0, 19); + assert_inner_backend( + &mut backend1, + [ + (String::from("b"), 1), + (String::from("c"), 4), + (String::from("d"), 5), + (String::from("foo1"), 1), + ], + ); + assert_inner_backend( + &mut backend2, + [(String::from("a"), 2), (String::from("b"), 6)], + ); + + // now we can fill up data up to the capacity again + backend1.set(String::from("foo2"), 2usize); + pool.wait_converged().await; + assert_eq!(pool.current().0, 21); + assert_inner_backend( + &mut backend1, + [ + (String::from("b"), 1), + (String::from("c"), 4), + (String::from("d"), 5), + (String::from("foo1"), 1), + (String::from("foo2"), 2), + ], + ); + assert_inner_backend( + &mut backend2, + [(String::from("a"), 2), (String::from("b"), 6)], + ); + + // can evict two keys at the same time + backend1.set(String::from("foo3"), 2usize); + pool.wait_converged().await; + assert_eq!(pool.current().0, 18); + assert_inner_backend( + &mut backend1, + [ + (String::from("d"), 5), + (String::from("foo1"), 1), + (String::from("foo2"), 2), + (String::from("foo3"), 2), + ], + ); + assert_inner_backend( + &mut backend2, + [(String::from("a"), 2), (String::from("b"), 6)], + ); + + // can evict from another backend + backend1.set(String::from("foo4"), 4usize); + pool.wait_converged().await; + assert_eq!(pool.current().0, 20); + assert_inner_backend( + &mut backend1, + [ + (String::from("d"), 5), + (String::from("foo1"), 1), + (String::from("foo2"), 2), + (String::from("foo3"), 2), + (String::from("foo4"), 4), + ], + ); + assert_inner_backend(&mut backend2, [(String::from("b"), 6)]); + + // can evict multiple timestamps + backend1.set(String::from("foo5"), 7usize); + 
pool.wait_converged().await; + assert_eq!(pool.current().0, 16); + assert_inner_backend( + &mut backend1, + [ + (String::from("foo1"), 1), + (String::from("foo2"), 2), + (String::from("foo3"), 2), + (String::from("foo4"), 4), + (String::from("foo5"), 7), + ], + ); + assert_inner_backend(&mut backend2, []); + } + + #[tokio::test] + async fn test_get_updates_last_used() { + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_nanos(0))); + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(6), + Arc::new(metric::Registry::new()), + &Handle::current(), + )); + let resource_estimator = Arc::new(TestResourceEstimator {}); + + let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id1", + Arc::clone(&resource_estimator) as _, + )); + + backend.set(String::from("a"), 1usize); + backend.set(String::from("b"), 2usize); + + time_provider.inc(Duration::from_millis(1)); + + backend.set(String::from("c"), 3usize); + pool.wait_converged().await; + + time_provider.inc(Duration::from_millis(1)); + + assert_eq!(backend.get(&String::from("a")), Some(1usize)); + + assert_eq!(pool.current().0, 6); + assert_inner_backend( + &mut backend, + [ + (String::from("a"), 1), + (String::from("b"), 2), + (String::from("c"), 3), + ], + ); + + backend.set(String::from("foo"), 3usize); + pool.wait_converged().await; + assert_eq!(pool.current().0, 4); + assert_inner_backend( + &mut backend, + [(String::from("a"), 1), (String::from("foo"), 3)], + ); + } + + #[tokio::test] + async fn test_oversized_entries() { + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_nanos(0))); + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(10), + Arc::new(metric::Registry::new()), + &Handle::current(), + )); + let resource_estimator = Arc::new(TestResourceEstimator {}); + + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(LruPolicy::new( + 
Arc::clone(&pool), + "id1", + Arc::clone(&resource_estimator) as _, + )); + + backend.set(String::from("a"), 1usize); + pool.wait_converged().await; + backend.set(String::from("b"), 11usize); + pool.wait_converged().await; + + // "a" did NOT get evicted. Instead we removed the oversized entry straight away. + assert_eq!(pool.current().0, 1); + assert_inner_backend(&mut backend, [(String::from("a"), 1)]); + } + + #[tokio::test] + async fn test_values_are_dropped() { + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_nanos(0))); + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(3), + Arc::new(metric::Registry::new()), + &Handle::current(), + )); + + #[derive(Debug)] + struct Provider {} + + impl ResourceEstimator for Provider { + type K = Arc; + type V = Arc; + type S = TestSize; + + fn consumption(&self, _k: &Self::K, v: &Self::V) -> Self::S { + TestSize(*v.as_ref()) + } + } + + let resource_estimator = Arc::new(Provider {}); + + let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id1", + Arc::clone(&resource_estimator) as _, + )); + + let k1 = Arc::new(String::from("a")); + let v1 = Arc::new(2usize); + let k2 = Arc::new(String::from("b")); + let v2 = Arc::new(2usize); + let k1_weak = Arc::downgrade(&k1); + let v1_weak = Arc::downgrade(&v1); + + backend.set(k1, v1); + pool.wait_converged().await; + + time_provider.inc(Duration::from_millis(1)); + + backend.set(k2, v2); + pool.wait_converged().await; + + assert_eq!(k1_weak.strong_count(), 0); + assert_eq!(v1_weak.strong_count(), 0); + } + + #[tokio::test] + async fn test_backends_are_dropped() { + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_nanos(0))); + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(3), + Arc::new(metric::Registry::new()), + &Handle::current(), + )); + + let resource_estimator = Arc::new(TestResourceEstimator {}); + + #[derive(Debug)] + struct 
Backend { + #[allow(dead_code)] + marker: Arc<()>, + inner: HashMap, + } + + impl CacheBackend for Backend { + type K = String; + type V = usize; + + fn get(&mut self, k: &Self::K) -> Option { + self.inner.get(k).copied() + } + + fn set(&mut self, k: Self::K, v: Self::V) { + self.inner.set(k, v) + } + + fn remove(&mut self, k: &Self::K) { + self.inner.remove(k); + } + + fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + fn as_any(&self) -> &dyn Any { + self as &dyn Any + } + } + + let marker = Arc::new(()); + let marker_weak = Arc::downgrade(&marker); + + let mut backend = PolicyBackend::new( + Box::new(Backend { + marker, + inner: HashMap::new(), + }), + Arc::clone(&time_provider) as _, + ); + backend.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id1", + Arc::clone(&resource_estimator) as _, + )); + backend.set(String::from("a"), 2usize); + + drop(backend); + assert_eq!(marker_weak.strong_count(), 0); + } + + #[tokio::test] + async fn test_metrics() { + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_nanos(0))); + let metric_registry = Arc::new(metric::Registry::new()); + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(10), + Arc::clone(&metric_registry), + &Handle::current(), + )); + let resource_estimator = Arc::new(TestResourceEstimator {}); + + let mut reporter = RawReporter::default(); + metric_registry.report(&mut reporter); + assert_eq!( + reporter + .metric("cache_lru_pool_limit") + .unwrap() + .observation(&[("pool", "pool"), ("unit", "bytes")]) + .unwrap(), + &Observation::U64Gauge(10) + ); + assert_eq!( + reporter + .metric("cache_lru_pool_usage") + .unwrap() + .observation(&[("pool", "pool"), ("unit", "bytes")]) + .unwrap(), + &Observation::U64Gauge(0) + ); + + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id", + Arc::clone(&resource_estimator) as _, + )); + + let mut reporter = RawReporter::default(); + 
metric_registry.report(&mut reporter); + assert_eq!( + reporter + .metric("cache_lru_pool_limit") + .unwrap() + .observation(&[("pool", "pool"), ("unit", "bytes")]) + .unwrap(), + &Observation::U64Gauge(10) + ); + assert_eq!( + reporter + .metric("cache_lru_pool_usage") + .unwrap() + .observation(&[("pool", "pool"), ("unit", "bytes")]) + .unwrap(), + &Observation::U64Gauge(0) + ); + assert_eq!( + reporter + .metric("cache_lru_member_count") + .unwrap() + .observation(&[("pool", "pool"), ("member", "id")]) + .unwrap(), + &Observation::U64Gauge(0) + ); + assert_eq!( + reporter + .metric("cache_lru_member_usage") + .unwrap() + .observation(&[("pool", "pool"), ("unit", "bytes"), ("member", "id")]) + .unwrap(), + &Observation::U64Gauge(0) + ); + assert_eq!( + reporter + .metric("cache_lru_member_evicted") + .unwrap() + .observation(&[("pool", "pool"), ("member", "id")]) + .unwrap(), + &Observation::U64Counter(0) + ); + + backend.set(String::from("a"), 1usize); // usage = 1 + pool.wait_converged().await; + backend.set(String::from("b"), 2usize); // usage = 3 + pool.wait_converged().await; + backend.set(String::from("b"), 3usize); // usage = 4 + pool.wait_converged().await; + backend.set(String::from("c"), 4usize); // usage = 8 + pool.wait_converged().await; + backend.set(String::from("d"), 3usize); // usage = 10 (evicted "a") + pool.wait_converged().await; + backend.remove(&String::from("c")); // usage = 6 + pool.wait_converged().await; + + let mut reporter = RawReporter::default(); + metric_registry.report(&mut reporter); + assert_eq!( + reporter + .metric("cache_lru_pool_limit") + .unwrap() + .observation(&[("pool", "pool"), ("unit", "bytes")]) + .unwrap(), + &Observation::U64Gauge(10) + ); + assert_eq!( + reporter + .metric("cache_lru_pool_usage") + .unwrap() + .observation(&[("pool", "pool"), ("unit", "bytes")]) + .unwrap(), + &Observation::U64Gauge(6) + ); + assert_eq!( + reporter + .metric("cache_lru_member_count") + .unwrap() + .observation(&[("pool", "pool"), 
("member", "id")]) + .unwrap(), + &Observation::U64Gauge(2), // b and d + ); + assert_eq!( + reporter + .metric("cache_lru_member_usage") + .unwrap() + .observation(&[("pool", "pool"), ("unit", "bytes"), ("member", "id")]) + .unwrap(), + &Observation::U64Gauge(6) + ); + assert_eq!( + reporter + .metric("cache_lru_member_evicted") + .unwrap() + .observation(&[("pool", "pool"), ("member", "id")]) + .unwrap(), + &Observation::U64Counter(1) + ); + } + + /// A note regarding the test flavor: + /// + /// The main generic test function is not async, so the background clean-up would never fire because we don't + /// yield to tokio. The test will pass in both cases (w/ a single worker and w/ multiple), however if the + /// background worker is a actually doing anything it might be a more realistic test case. + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_generic_backend() { + use crate::backend::test_util::test_generic; + + #[derive(Debug)] + struct ZeroSizeProvider {} + + impl ResourceEstimator for ZeroSizeProvider { + type K = u8; + type V = String; + type S = TestSize; + + fn consumption(&self, _k: &Self::K, _v: &Self::V) -> Self::S { + TestSize(0) + } + } + + test_generic(|| { + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_nanos(0))); + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(10), + Arc::new(metric::Registry::new()), + &Handle::current(), + )); + let resource_estimator = Arc::new(ZeroSizeProvider {}); + + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id", + Arc::clone(&resource_estimator) as _, + )); + backend + }); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_deadlock() { + // Regression test for . 
+ test_deadlock_inner(Duration::from_secs(1)).await; + + // Regression test for + for _ in 0..100 { + test_deadlock_inner(Duration::from_millis(1)).await; + } + } + + async fn test_deadlock_inner(test_duration: Duration) { + #[derive(Debug)] + struct OneSizeProvider {} + + impl ResourceEstimator for OneSizeProvider { + type K = u128; + type V = (); + type S = TestSize; + + fn consumption(&self, _k: &Self::K, _v: &Self::V) -> Self::S { + TestSize(1) + } + } + + let time_provider = Arc::new(SystemProvider::new()) as _; + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(100), + Arc::new(metric::Registry::new()), + &Handle::current(), + )); + let resource_estimator = Arc::new(OneSizeProvider {}); + + let mut backend1 = PolicyBackend::hashmap_backed(Arc::clone(&time_provider)); + backend1.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id1", + Arc::clone(&resource_estimator) as _, + )); + + let mut backend2 = PolicyBackend::hashmap_backed(Arc::clone(&time_provider)); + backend2.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id2", + Arc::clone(&resource_estimator) as _, + )); + + let worker1 = tokio::spawn(async move { + let mut counter = 0u128; + loop { + backend1.set(counter, ()); + counter += 2; + tokio::task::yield_now().await; + } + }); + let worker2 = tokio::spawn(async move { + let mut counter = 1u128; + loop { + backend2.set(counter, ()); + counter += 2; + tokio::task::yield_now().await; + } + }); + + tokio::time::sleep(test_duration).await; + + worker1.abort(); + worker2.abort(); + } + + #[tokio::test] + async fn test_efficient_eviction() { + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_nanos(0))); + let metric_registry = Arc::new(metric::Registry::new()); + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(10), + Arc::clone(&metric_registry), + &Handle::current(), + )); + let resource_estimator = Arc::new(TestResourceEstimator {}); + + let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) 
as _); + backend.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id", + Arc::clone(&resource_estimator) as _, + )); + + // fill up pool + for i in 0..10 { + backend.set(i.to_string(), 1usize); + } + assert_eq!(pool.current(), TestSize(10)); + + // evict all members using a single large one + time_provider.inc(Duration::from_millis(1)); + backend.set(String::from("big"), 10usize); + pool.wait_converged().await; + assert_eq!(pool.current(), TestSize(10)); + + let mut reporter = RawReporter::default(); + metric_registry.report(&mut reporter); + assert_eq!( + reporter + .metric("cache_lru_member_evicted") + .unwrap() + .observation(&[("pool", "pool"), ("member", "id")]) + .unwrap(), + // it is important that all 10 items are evicted with a single eviction + &Observation::U64Counter(10) + ); + } + + #[tokio::test] + async fn test_eviction_half_half() { + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_nanos(0))); + let metric_registry = Arc::new(metric::Registry::new()); + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(20), + Arc::clone(&metric_registry), + &Handle::current(), + )); + let resource_estimator = Arc::new(TestResourceEstimator {}); + + let mut backend1 = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend1.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id1", + Arc::clone(&resource_estimator) as _, + )); + + let mut backend2 = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend2.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id2", + Arc::clone(&resource_estimator) as _, + )); + + // fill up pool + for i in 0..10 { + backend1.set(i.to_string(), 1usize); + backend2.set(i.to_string(), 1usize); + time_provider.inc(Duration::from_millis(1)); + } + assert_eq!(pool.current(), TestSize(20)); + + // evict members using a single large one + time_provider.inc(Duration::from_millis(1)); + backend1.set(String::from("big"), 10usize); + pool.wait_converged().await; + 
assert_eq!(pool.current(), TestSize(20)); + + // every member lost 5 entries + // Note: backend1 has 5+1 items because it own the "big" key + assert_inner_len(&mut backend1, 6); + assert_inner_len(&mut backend2, 5); + } + + #[tokio::test] + async fn test_eviction_one_member_all_other_member_some() { + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_nanos(0))); + let metric_registry = Arc::new(metric::Registry::new()); + let pool = Arc::new(ResourcePool::new( + "pool", + TestSize(3), + Arc::clone(&metric_registry), + &Handle::current(), + )); + let resource_estimator = Arc::new(TestResourceEstimator {}); + + let mut backend1 = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend1.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id1", + Arc::clone(&resource_estimator) as _, + )); + + let mut backend2 = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend2.add_policy(LruPolicy::new( + Arc::clone(&pool), + "id2", + Arc::clone(&resource_estimator) as _, + )); + + // fill up pool + backend1.set(String::from("a"), 1usize); + time_provider.inc(Duration::from_millis(1)); + backend2.set(String::from("a"), 1usize); + time_provider.inc(Duration::from_millis(1)); + backend2.set(String::from("b"), 1usize); + assert_eq!(pool.current(), TestSize(3)); + + // evict members using a single large one + time_provider.inc(Duration::from_millis(1)); + backend2.set(String::from("big"), 2usize); + pool.wait_converged().await; + assert_eq!(pool.current(), TestSize(3)); + + assert_inner_backend(&mut backend1, []); + assert_inner_backend( + &mut backend2, + [(String::from("b"), 1usize), (String::from("big"), 2usize)], + ); + } + + #[derive(Debug)] + struct TestResourceEstimator {} + + impl ResourceEstimator for TestResourceEstimator { + type K = String; + type V = usize; + type S = TestSize; + + fn consumption(&self, _k: &Self::K, v: &Self::V) -> Self::S { + TestSize(*v) + } + } + + #[track_caller] + fn assert_inner_backend( + 
backend: &mut PolicyBackend, + data: [(String, usize); N], + ) { + let inner_backend = backend.inner_ref(); + let inner_backend = inner_backend + .as_any() + .downcast_ref::>() + .unwrap(); + let expected = HashMap::from(data); + assert_eq!(inner_backend, &expected); + } + + #[track_caller] + fn assert_inner_len(backend: &mut PolicyBackend, len: usize) { + let inner_backend = backend.inner_ref(); + let inner_backend = inner_backend + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(inner_backend.len(), len); + } +} diff --git a/cache_system/src/backend/policy/mod.rs b/cache_system/src/backend/policy/mod.rs new file mode 100644 index 0000000..c503c2a --- /dev/null +++ b/cache_system/src/backend/policy/mod.rs @@ -0,0 +1,1974 @@ +//! Policy framework for [backends](crate::backend::CacheBackend). + +use std::{ + cell::RefCell, + collections::{HashMap, VecDeque}, + fmt::Debug, + hash::Hash, + marker::PhantomData, + ops::Deref, + sync::{Arc, Weak}, +}; + +use iox_time::{Time, TimeProvider}; +use parking_lot::{lock_api::ArcReentrantMutexGuard, RawMutex, RawThreadId, ReentrantMutex}; + +use super::CacheBackend; + +pub mod lru; +pub mod refresh; +pub mod remove_if; +pub mod ttl; + +#[cfg(test)] +mod integration_tests; + +/// Convenience macro to easily follow the borrow/lock chain of [`StrongSharedInner`]. +/// +/// This cannot just be a method because we cannot return references to local variables. +macro_rules! lock_inner { + ($guard:ident = $inner:expr) => { + let $guard = $inner.lock(); + let $guard = $guard.try_borrow_mut().expect("illegal recursive access"); + }; + (mut $guard:ident = $inner:expr) => { + let $guard = $inner.lock(); + let mut $guard = $guard.try_borrow_mut().expect("illegal recursive access"); + }; +} + +/// Backend that is controlled by different policies. +/// +/// # Policies & Recursion +/// +/// Policies have two tasks: +/// +/// - initiate changes (e.g. 
based on timers) +/// - react to changes +/// +/// Getting data from a [`PolicyBackend`] and feeding data back into it in a somewhat synchronous +/// manner sounds really close to recursion. Uncontrolled recursion however is bad for the +/// following reasons: +/// +/// 1. **Stack space:** We may easily run out of stack space. +/// 2. **Ownership:** Looping back into the same data structure can easily lead to deadlocks (data +/// corruption is luckily prevented by Rust's ownership model). +/// +/// However sometimes we need to have interactions of policies in a "recursive" manner. E.g.: +/// +/// 1. A refresh policies updates a value based on a timer. The value gets bigger. +/// 2. Some resource-pool policy decides that this is now too much data and wants to evict data. +/// 3. The refresh policy gets informed about the values that are removed so it can stop refreshing +/// them. +/// +/// The solution that [`PolicyBackend`] uses is the following: +/// +/// All interaction of the policy with a [`PolicyBackend`] happens through a proxy object called +/// [`ChangeRequest`]. The [`ChangeRequest`] encapsulates a single atomic "transaction" on the +/// underlying store. This can be a simple operation as [`REMOVE`](CacheBackend::remove) but also +/// compound operations like "get+remove" (e.g. to check if a value needs to be pruned from the +/// cache). The policy has two ways of issuing [`ChangeRequest`]s: +/// +/// 1. **Initial / self-driven:** Upon creation the policy receives a [`CallbackHandle`] that it +/// can use initiate requests. This handle must only be used to create requests "out of thin +/// air" (e.g. based on a timer). It MUST NOT be used to react to changes (see next point) to +/// avoid deadlocks. +/// 2. **Reactions:** Each policy implements a [`Subscriber`] that receives notifications for each +/// changes. These notification return [`ChangeRequest`]s that the policy wishes to be +/// performed. This construct is designed to avoid recursion. 
+/// +/// Also note that a policy that uses the subscriber interface MUST NOT hold locks on their +/// internal data structure while performing _initial requests_ to avoid deadlocks (since the +/// subscriber will be informed about the changes). +/// +/// We cannot guarantee that policies fulfill this interface, but [`PolicyBackend`] performs some +/// sanity checks (e.g. it will catch if the same thread that started an initial requests recurses +/// into another initial request). +/// +/// # Change Propagation +/// +/// Each [`ChangeRequest`] is processed atomically, so "get + set" / "compare + exchange" patterns +/// work as expected. +/// +/// Changes will be propagated "breadth first". This means that the initial changes will form a +/// task list. For every task in this list (front to back), we will execute the [`ChangeRequest`]. +/// Every change that is performed within this request (usually only one) we propagate the change +/// as follows: +/// +/// 1. underlying backend +/// 2. policies (in the order they where added) +/// +/// From step 2 we collect new change requests that will be added to the back of the task list. +/// +/// The original requests will return to the caller once all tasks are completed. +/// +/// When a [`ChangeRequest`] performs multiple operations -- e.g. [`GET`](CacheBackend::get) and +/// [`SET`](CacheBackend::set) -- we first inform all subscribers about the first operation (in +/// this case: [`GET`](CacheBackend::get)) and collect the resulting [`ChangeRequest`]s. Then we +/// process the second operation (in this case: [`SET`](CacheBackend::set)). +/// +/// # `GET` +/// +/// The return value for [`CacheBackend::get`] is fetched from the inner backend AFTER all changes +/// are applied. +/// +/// Note [`ChangeRequest::get`] has no way of returning a result to the [`Subscriber`] that created +/// it. The "changes" solely act as some kind of "keep alive" / "this was used" signal. 
+/// +/// # Example +/// +/// **The policies in these examples are deliberately silly but simple!** +/// +/// Let's start with a purely reactive policy that will round up all integer values to the next +/// even number: +/// +/// ``` +/// use std::{ +/// collections::HashMap, +/// sync::Arc, +/// }; +/// use cache_system::backend::{ +/// CacheBackend, +/// policy::{ +/// ChangeRequest, +/// PolicyBackend, +/// Subscriber, +/// }, +/// }; +/// use iox_time::{ +/// SystemProvider, +/// Time, +/// }; +/// +/// #[derive(Debug)] +/// struct EvenNumberPolicy; +/// +/// type CR = ChangeRequest<'static, &'static str, u64>; +/// +/// impl Subscriber for EvenNumberPolicy { +/// type K = &'static str; +/// type V = u64; +/// +/// fn set(&mut self, k: &&'static str, v: &u64, _now: Time) -> Vec { +/// // When new key `k` is set to value `v` if `v` is odd, +/// // request a change to set `k` to `v+1` +/// if v % 2 == 1 { +/// vec![CR::set(k, v + 1)] +/// } else { +/// vec![] +/// } +/// } +/// } +/// +/// let mut backend = PolicyBackend::new( +/// Box::new(HashMap::new()), +/// Arc::new(SystemProvider::new()), +/// ); +/// backend.add_policy(|_callback_backend| EvenNumberPolicy); +/// +/// backend.set("foo", 8); +/// backend.set("bar", 9); +/// +/// assert_eq!(backend.get(&"foo"), Some(8)); +/// assert_eq!(backend.get(&"bar"), Some(10)); +/// ``` +/// +/// And here is a more active backend that regularly writes the current system time to a key: +/// +/// ``` +/// use std::{ +/// collections::HashMap, +/// sync::{ +/// Arc, +/// atomic::{AtomicBool, Ordering}, +/// }, +/// thread::{JoinHandle, sleep, spawn}, +/// time::{Duration, Instant}, +/// }; +/// use cache_system::backend::{ +/// CacheBackend, +/// policy::{ +/// ChangeRequest, +/// PolicyBackend, +/// Subscriber, +/// }, +/// }; +/// use iox_time::SystemProvider; +/// +/// #[derive(Debug)] +/// struct NowPolicy { +/// cancel: Arc, +/// join_handle: Option>, +/// }; +/// +/// impl Drop for NowPolicy { +/// fn drop(&mut 
self) { +/// self.cancel.store(true, Ordering::SeqCst); +/// self.join_handle +/// .take() +/// .expect("worker thread present") +/// .join() +/// .expect("worker thread finished"); +/// } +/// } +/// +/// type CR = ChangeRequest<'static, &'static str, Instant>; +/// +/// impl Subscriber for NowPolicy { +/// type K = &'static str; +/// type V = Instant; +/// } +/// +/// let mut backend = PolicyBackend::new( +/// Box::new(HashMap::new()), +/// Arc::new(SystemProvider::new()), +/// ); +/// backend.add_policy(|mut callback_handle| { +/// let cancel = Arc::new(AtomicBool::new(false)); +/// let cancel_captured = Arc::clone(&cancel); +/// let join_handle = spawn(move || { +/// loop { +/// if cancel_captured.load(Ordering::SeqCst) { +/// break; +/// } +/// callback_handle.execute_requests(vec![ +/// CR::set("now", Instant::now()), +/// ]); +/// sleep(Duration::from_millis(1)); +/// } +/// }); +/// NowPolicy{cancel, join_handle: Some(join_handle)} +/// }); +/// +/// +/// // eventually we should see a key +/// let t_start = Instant::now(); +/// loop { +/// if let Some(t) = backend.get(&"now") { +/// // value should be fresh +/// assert!(t.elapsed() < Duration::from_millis(100)); +/// break; +/// } +/// +/// assert!(t_start.elapsed() < Duration::from_secs(1)); +/// sleep(Duration::from_millis(10)); +/// } +/// ``` +#[derive(Debug)] +pub struct PolicyBackend +where + K: Clone + Eq + Hash + Ord + Debug + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + inner: StrongSharedInner, +} + +impl PolicyBackend +where + K: Clone + Eq + Hash + Ord + Debug + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + /// Create new backend w/o any policies. + /// + /// # Panic + /// + /// Panics if `inner` is not empty. 
+ pub fn new( + inner: Box + Send>, + time_provider: Arc, + ) -> Self { + assert!(inner.is_empty(), "inner backend is not empty"); + + Self { + inner: Arc::new(ReentrantMutex::new(RefCell::new(PolicyBackendInner { + inner, + subscribers: Vec::new(), + time_provider, + }))), + } + } + + /// Create a new backend with a HashMap as the [`CacheBackend`]. + pub fn hashmap_backed(time_provider: Arc) -> Self { + // See . This clippy lint suggests + // replacing `Box::new(HashMap::new())` with `Box::default()`, which in most cases would be + // shorter, but because this type is actually a `Box`, the replacement would + // need to be `Box::>::default()`, which doesn't seem like an improvement. + #[allow(clippy::box_default)] + Self::new(Box::new(HashMap::new()), Arc::clone(&time_provider)) + } + + /// Adds new policy. + /// + /// See documentation of [`PolicyBackend`] for more information. + /// + /// This is called with a function that receives the "callback backend" to this backend and + /// should return a [`Subscriber`]. This loopy construct was chosen to discourage the leakage + /// of the "callback backend" to any other object. + pub fn add_policy(&mut self, policy_constructor: C) + where + C: FnOnce(CallbackHandle) -> S, + S: Subscriber, + { + let callback_handle = CallbackHandle { + inner: Arc::downgrade(&self.inner), + }; + let subscriber = policy_constructor(callback_handle); + lock_inner!(mut guard = self.inner); + guard.subscribers.push(Box::new(subscriber)); + } + + /// Provide temporary read-only access to the underlying backend. + /// + /// This is mostly useful for debugging and testing. + pub fn inner_ref(&mut self) -> InnerBackendRef<'_, K, V> { + // NOTE: We deliberately use a mutable reference here to prevent users from using `` while we hold a lock to the underlying backend. 
+ + inner_ref::build(Arc::clone(&self.inner)) + } +} + +impl CacheBackend for PolicyBackend +where + K: Clone + Eq + Hash + Ord + Debug + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + type K = K; + type V = V; + + fn get(&mut self, k: &Self::K) -> Option { + lock_inner!(mut guard = self.inner); + perform_changes(&mut guard, vec![ChangeRequest::get(k.clone())]); + + // poll inner backend AFTER everything has settled + guard.inner.get(k) + } + + fn set(&mut self, k: Self::K, v: Self::V) { + lock_inner!(mut guard = self.inner); + perform_changes(&mut guard, vec![ChangeRequest::set(k, v)]); + } + + fn remove(&mut self, k: &Self::K) { + lock_inner!(mut guard = self.inner); + perform_changes(&mut guard, vec![ChangeRequest::remove(k.clone())]); + } + + fn is_empty(&self) -> bool { + lock_inner!(guard = self.inner); + guard.inner.is_empty() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +/// Handle that allows a [`Subscriber`] to send [`ChangeRequest`]s back to the [`PolicyBackend`] +/// that owns that very [`Subscriber`]. +#[derive(Debug)] +pub struct CallbackHandle +where + K: Clone + Eq + Hash + Ord + Debug + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + inner: WeakSharedInner, +} + +impl CallbackHandle +where + K: Clone + Eq + Hash + Ord + Debug + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + /// Start a series of requests to the [`PolicyBackend`] that is referenced by this handle. + /// + /// This method returns AFTER the requests and all the follow-up changes requested by all + /// policies are played out. You should NOT hold a lock on your policies internal data + /// structures while calling this function if you plan to also [subscribe](Subscriber) to + /// changes because this would easily lead to deadlocks. 
+ pub fn execute_requests(&mut self, change_requests: Vec>) { + let Some(inner) = self.inner.upgrade() else { + // backend gone, can happen during shutdowns, try not to panic + return; + }; + + lock_inner!(mut guard = inner); + perform_changes(&mut guard, change_requests); + } +} + +#[derive(Debug)] +struct PolicyBackendInner +where + K: Clone + Eq + Hash + Ord + Debug + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + /// Underlying cache backend. + inner: Box + Send>, + + /// List of subscribers. + subscribers: Vec>>, + + /// Time provider. + time_provider: Arc, +} + +type WeakSharedInner = Weak>>>; +type StrongSharedInner = Arc>>>; + +/// Perform changes breadth first. +fn perform_changes( + inner: &mut PolicyBackendInner, + change_requests: Vec>, +) where + K: Clone + Eq + Hash + Ord + Debug + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + let mut tasks = VecDeque::from(change_requests); + let now = inner.time_provider.now(); + + while let Some(change_request) = tasks.pop_front() { + let mut recorder = Recorder { + inner: inner.inner.as_mut(), + records: vec![], + }; + + change_request.eval(&mut recorder); + + for record in recorder.records { + for subscriber in &mut inner.subscribers { + let requests = match &record { + Record::Get { k } => subscriber.get(k, now), + Record::Set { k, v } => subscriber.set(k, v, now), + Record::Remove { k } => subscriber.remove(k, now), + }; + + tasks.extend(requests.into_iter()); + } + } + } +} + +/// Subscriber to change events. +pub trait Subscriber: Debug + Send + 'static { + /// Cache key. + type K: Clone + Eq + Hash + Ord + Debug + Send + 'static; + + /// Cached value. + type V: Clone + Debug + Send + 'static; + + /// Get value for given key if it exists. + /// + /// The current time `now` is provided as a parameter so that all policies and backends use a + /// unified timestamp rather than their own provider, which is more consistent and performant. 
+ fn get(&mut self, _k: &Self::K, _now: Time) -> Vec> { + // do nothing by default + vec![] + } + + /// Set value for given key. + /// + /// It is OK to set and override a key that already exists. + /// + /// The current time `now` is provided as a parameter so that all policies and backends use a + /// unified timestamp rather than their own provider, which is more consistent and performant. + fn set( + &mut self, + _k: &Self::K, + _v: &Self::V, + _now: Time, + ) -> Vec> { + // do nothing by default + vec![] + } + + /// Remove value for given key. + /// + /// It is OK to remove a key even when it does not exist. + /// + /// The current time `now` is provided as a parameter so that all policies and backends use a + /// unified timestamp rather than their own provider, which is more consistent and performant. + fn remove( + &mut self, + _k: &Self::K, + _now: Time, + ) -> Vec> { + // do nothing by default + vec![] + } +} + +/// A change request to a backend. +pub struct ChangeRequest<'a, K, V> +where + K: Clone + Eq + Hash + Ord + Debug + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + fun: ChangeRequestFn<'a, K, V>, +} + +impl<'a, K, V> Debug for ChangeRequest<'a, K, V> +where + K: Clone + Eq + Hash + Ord + Debug + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CacheRequest").finish_non_exhaustive() + } +} + +impl<'a, K, V> ChangeRequest<'a, K, V> +where + K: Clone + Eq + Hash + Ord + Debug + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + /// Custom way of constructing a change request. + /// + /// This is considered a rather low-level function and you should prefer the higher-level + /// constructs like [`get`](Self::get), [`set`](Self::set), and [`remove`](Self::remove). + /// + /// Takes a "callback backend" and can freely act on it. 
The underlying backend of + /// [`PolicyBackend`] is guaranteed to be locked during a single request, so "get + modify" + /// patterns work out of the box without the need to fear interleaving modifications. + pub fn from_fn(f: F) -> Self + where + F: for<'b, 'c> FnOnce(&'c mut Recorder<'b, K, V>) + 'a, + { + Self { fun: Box::new(f) } + } + + /// [`GET`](CacheBackend::get) + pub fn get(k: K) -> Self { + Self::from_fn(move |backend| { + backend.get(&k); + }) + } + + /// [`SET`](CacheBackend::set) + pub fn set(k: K, v: V) -> Self { + Self::from_fn(move |backend| { + backend.set(k, v); + }) + } + + /// [`REMOVE`](CacheBackend::remove). + pub fn remove(k: K) -> Self { + Self::from_fn(move |backend| { + backend.remove(&k); + }) + } + + /// Ensure that backend is empty and panic otherwise. + /// + /// This is mostly useful during initialization. + pub fn ensure_empty() -> Self { + Self::from_fn(|backend| { + assert!(backend.is_empty(), "inner backend is not empty"); + }) + } + + /// Execute this change request. + pub fn eval(self, backend: &mut Recorder<'_, K, V>) { + (self.fun)(backend); + } +} + +/// Function captured within [`ChangeRequest`]. +type ChangeRequestFn<'a, K, V> = Box FnOnce(&'c mut Recorder<'b, K, V>) + 'a>; + +/// Records of interactions with the callback [`CacheBackend`]. +#[derive(Debug, PartialEq)] +enum Record { + /// [`GET`](CacheBackend::get) + Get { + /// Key. + k: K, + }, + + /// [`SET`](CacheBackend::set) + Set { + /// Key. + k: K, + + /// Value. + v: V, + }, + + /// [`REMOVE`](CacheBackend::remove). + Remove { + /// Key. + k: K, + }, +} + +/// Specialized [`CacheBackend`] that forwards changes and requests to the underlying backend of +/// [`PolicyBackend`] but also records all changes into [`Record`]s. 
+#[derive(Debug)] +pub struct Recorder<'a, K, V> +where + K: Clone + Eq + Hash + Ord + Debug + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + inner: &'a mut (dyn CacheBackend + Send), + records: Vec>, +} + +impl<'a, K, V> Recorder<'a, K, V> +where + K: Clone + Eq + Hash + Ord + Debug + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + /// Perform a [`GET`](CacheBackend::get) request that is NOT seen by other policies. + /// + /// This is helpful if you just want to check the underlying data of a key without treating it + /// as "used". + /// + /// Note that this functionality only exists for [`GET`](CacheBackend::get) requests, not for + /// modifying requests like [`SET`](CacheBackend::set) or [`REMOVE`](CacheBackend::remove) + /// since they always require policies to be in-sync. + pub fn get_untracked(&mut self, k: &K) -> Option { + self.inner.get(k) + } +} + +impl<'a, K, V> CacheBackend for Recorder<'a, K, V> +where + K: Clone + Eq + Hash + Ord + Debug + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + type K = K; + type V = V; + + fn get(&mut self, k: &Self::K) -> Option { + self.records.push(Record::Get { k: k.clone() }); + self.inner.get(k) + } + + fn set(&mut self, k: Self::K, v: Self::V) { + self.records.push(Record::Set { + k: k.clone(), + v: v.clone(), + }); + self.inner.set(k, v); + } + + fn remove(&mut self, k: &Self::K) { + self.records.push(Record::Remove { k: k.clone() }); + self.inner.remove(k); + } + + fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + fn as_any(&self) -> &dyn std::any::Any { + panic!("don't any-cast the recorder please") + } +} + +/// Helper module that wraps the implementation of [`InnerBackendRef`]. +/// +/// This is required because [`ouroboros`] generates a bunch of code that we do not want to leak all over the place. 
mod inner_ref {
    #![allow(non_snake_case, clippy::future_not_send)]

    use super::*;
    use ouroboros::self_referencing;

    /// Read-only ref to the inner backend of [`PolicyBackend`] for debugging.
    #[self_referencing]
    pub struct InnerBackendRef<'a, K, V>
    where
        K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
        V: Clone + Debug + Send + 'static,
    {
        // NOTE(review): the guard's type parameters were lost in transit; this assumes the
        // parking_lot defaults (`RawMutex`, `RawThreadId`) -- confirm against the file's imports.
        l1: ArcReentrantMutexGuard<RawMutex, RawThreadId, RefCell<PolicyBackendInner<K, V>>>,
        #[borrows(l1)]
        #[covariant]
        l2: std::cell::RefMut<'this, PolicyBackendInner<K, V>>,
        // ties the ref's lifetime to the `&mut self` borrow of `PolicyBackend::inner_ref`
        _phantom: PhantomData<&'a mut ()>,
    }

    impl<'a, K, V> Deref for InnerBackendRef<'a, K, V>
    where
        K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
        V: Clone + Debug + Send + 'static,
    {
        type Target = dyn CacheBackend<K = K, V = V>;

        fn deref(&self) -> &Self::Target {
            self.borrow_l2().inner.as_ref()
        }
    }

    /// Lock the shared state and build a self-referencing read guard over it.
    pub(super) fn build<'a, K, V>(inner: StrongSharedInner<K, V>) -> InnerBackendRef<'a, K, V>
    where
        K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
        V: Clone + Debug + Send + 'static,
    {
        let inner = inner.lock_arc();
        InnerBackendRefBuilder {
            l1: inner,
            l2_builder: |l1| l1.borrow_mut(),
            _phantom: PhantomData,
        }
        .build()
    }
}

pub use inner_ref::InnerBackendRef;

#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Barrier, thread::JoinHandle};

    use iox_time::MockProvider;

    use super::*;

    #[allow(dead_code)]
    const fn assert_send<T: Send>() {}
    // NOTE(review): the concrete type argument was lost in transit; any `Send` instantiation
    // preserves the intent of this compile-time check -- confirm against upstream.
    const _: () = assert_send::<PolicyBackend<String, usize>>();

    #[test]
    #[should_panic(expected = "inner backend is not empty")]
    fn test_panic_inner_not_empty() {
        let time_provider = Arc::new(MockProvider::new(Time::MIN));
        PolicyBackend::new(
            Box::new(HashMap::from([(String::from("foo"), 1usize)])),
            time_provider,
        );
    }

    #[test]
    fn test_generic() {
        crate::backend::test_util::test_generic(|| {
            let time_provider = Arc::new(MockProvider::new(Time::MIN));
            PolicyBackend::hashmap_backed(time_provider)
        })
    }

    #[test]
    #[should_panic(expected = "test steps left")]
    fn test_meta_panic_steps_left() {
        let
time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 1, + }, + action: TestAction::ChangeRequests(vec![]), + }])); + } + + #[test] + #[should_panic(expected = "step left for get operation")] + fn test_meta_panic_requires_condition_get() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![])); + + backend.get(&String::from("a")); + } + + #[test] + #[should_panic(expected = "step left for set operation")] + fn test_meta_panic_requires_condition_set() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![])); + + backend.set(String::from("a"), 2); + } + + #[test] + #[should_panic(expected = "step left for remove operation")] + fn test_meta_panic_requires_condition_remove() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![])); + + backend.remove(&String::from("a")); + } + + #[test] + #[should_panic(expected = "Condition mismatch")] + fn test_meta_panic_checks_condition_get() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![]), + }])); + + backend.get(&String::from("b")); + } + + #[test] + #[should_panic(expected = "Condition mismatch")] + fn test_meta_panic_checks_condition_set() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = 
PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 1, + }, + action: TestAction::ChangeRequests(vec![]), + }])); + + backend.set(String::from("a"), 2); + } + + #[test] + #[should_panic(expected = "Condition mismatch")] + fn test_meta_panic_checks_condition_remove() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![TestStep { + condition: TestBackendInteraction::Remove { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![]), + }])); + + backend.remove(&String::from("b")); + } + + #[test] + fn test_basic_propagation() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![ + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 1, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("b"), + v: 2, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Remove { + k: String::from("b"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("b"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + ])); + + backend.set(String::from("a"), 1); + backend.set(String::from("b"), 2); + backend.remove(&String::from("b")); + + assert_eq!(backend.get(&String::from("a")), Some(1)); + assert_eq!(backend.get(&String::from("b")), None); + } + + #[test] + #[should_panic(expected = "illegal recursive access")] + fn 
test_panic_recursion_detection_get() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![TestStep { + condition: TestBackendInteraction::Remove { + k: String::from("a"), + }, + action: TestAction::CallBackendDirectly(TestBackendInteraction::Get { + k: String::from("b"), + }), + }])); + + backend.remove(&String::from("a")); + } + + #[test] + #[should_panic(expected = "illegal recursive access")] + fn test_panic_recursion_detection_set() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![TestStep { + condition: TestBackendInteraction::Remove { + k: String::from("a"), + }, + action: TestAction::CallBackendDirectly(TestBackendInteraction::Set { + k: String::from("b"), + v: 1, + }), + }])); + + backend.remove(&String::from("a")); + } + + #[test] + #[should_panic(expected = "illegal recursive access")] + fn test_panic_recursion_detection_remove() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![TestStep { + condition: TestBackendInteraction::Remove { + k: String::from("a"), + }, + action: TestAction::CallBackendDirectly(TestBackendInteraction::Remove { + k: String::from("b"), + }), + }])); + + backend.remove(&String::from("a")); + } + + #[test] + fn test_get_untracked() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![ + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 1, + }, + action: TestAction::ChangeRequests(vec![SendableChangeRequest::from_fn( + |backend| { + assert_eq!(backend.get_untracked(&String::from("a")), Some(1)); + }, + )]), + }, + 
// NO `GET` interaction triggered here! + ])); + + backend.set(String::from("a"), 1); + } + + #[test] + fn test_basic_get_set() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![ + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![SendableChangeRequest::set( + String::from("a"), + 1, + )]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 1, + }, + action: TestAction::ChangeRequests(vec![]), + }, + ])); + + assert_eq!(backend.get(&String::from("a")), Some(1)); + } + + #[test] + fn test_basic_get_get() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![ + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![SendableChangeRequest::get(String::from( + "a", + ))]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + ])); + + assert_eq!(backend.get(&String::from("a")), None); + } + + #[test] + fn test_basic_set_set_get_get() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![ + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 1, + }, + action: TestAction::ChangeRequests(vec![SendableChangeRequest::set( + String::from("b"), + 2, + )]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("b"), + v: 2, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: 
TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("b"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + ])); + + backend.set(String::from("a"), 1); + + assert_eq!(backend.get(&String::from("a")), Some(1)); + assert_eq!(backend.get(&String::from("b")), Some(2)); + } + + #[test] + fn test_basic_set_remove_get() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![ + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 1, + }, + action: TestAction::ChangeRequests(vec![SendableChangeRequest::remove( + String::from("a"), + )]), + }, + TestStep { + condition: TestBackendInteraction::Remove { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + ])); + + backend.set(String::from("a"), 1); + + assert_eq!(backend.get(&String::from("a")), None); + } + + #[test] + fn test_basic_remove_set_get_get() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![ + TestStep { + condition: TestBackendInteraction::Remove { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![SendableChangeRequest::set( + String::from("b"), + 1, + )]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("b"), + v: 1, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("b"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + ])); + + 
backend.remove(&String::from("a")); + + assert_eq!(backend.get(&String::from("a")), None); + assert_eq!(backend.get(&String::from("b")), Some(1)); + } + + #[test] + fn test_basic_remove_remove_get_get() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![ + TestStep { + condition: TestBackendInteraction::Remove { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![SendableChangeRequest::remove( + String::from("b"), + )]), + }, + TestStep { + condition: TestBackendInteraction::Remove { + k: String::from("b"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("b"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + ])); + + backend.remove(&String::from("a")); + + assert_eq!(backend.get(&String::from("a")), None); + assert_eq!(backend.get(&String::from("b")), None); + } + + #[test] + fn test_ordering_within_requests_vector() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![ + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 11, + }, + action: TestAction::ChangeRequests(vec![ + SendableChangeRequest::set(String::from("a"), 12), + SendableChangeRequest::set(String::from("a"), 13), + ]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 12, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 13, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), 
+ }, + action: TestAction::ChangeRequests(vec![]), + }, + ])); + + backend.set(String::from("a"), 11); + + assert_eq!(backend.get(&String::from("a")), Some(13)); + } + + #[test] + fn test_ordering_across_policies() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![ + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 11, + }, + action: TestAction::ChangeRequests(vec![SendableChangeRequest::set( + String::from("a"), + 12, + )]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 12, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 13, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + ])); + backend.add_policy(create_test_policy(vec![ + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 11, + }, + action: TestAction::ChangeRequests(vec![SendableChangeRequest::set( + String::from("a"), + 13, + )]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 12, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 13, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + ])); + + backend.set(String::from("a"), 11); + + assert_eq!(backend.get(&String::from("a")), Some(13)); + } + + #[test] + fn test_ping_pong() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + 
backend.add_policy(create_test_policy(vec![ + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 11, + }, + action: TestAction::ChangeRequests(vec![SendableChangeRequest::set( + String::from("a"), + 12, + )]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 12, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 13, + }, + action: TestAction::ChangeRequests(vec![SendableChangeRequest::set( + String::from("a"), + 14, + )]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 14, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + ])); + backend.add_policy(create_test_policy(vec![ + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 11, + }, + action: TestAction::ChangeRequests(vec![SendableChangeRequest::set( + String::from("a"), + 13, + )]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 12, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 13, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 14, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + ])); + + backend.set(String::from("a"), 11); + + assert_eq!(backend.get(&String::from("a")), Some(14)); + } + + #[test] + #[should_panic(expected = "this is a test")] + fn test_meta_multithread_panics_are_propagated() { + let barrier_pre = Arc::new(Barrier::new(2)); + let barrier_post = 
Arc::new(Barrier::new(1)); + + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 1, + }, + action: TestAction::CallBackendDelayed( + Arc::clone(&barrier_pre), + TestBackendInteraction::Panic, + Arc::clone(&barrier_post), + ), + }])); + + backend.set(String::from("a"), 1); + barrier_pre.wait(); + + // panic on drop + } + + /// Checks that a policy background task can access the "callback backend" without triggering + /// the "illegal recursion" detection. + #[test] + fn test_multithread() { + let barrier_pre = Arc::new(Barrier::new(2)); + let barrier_post = Arc::new(Barrier::new(2)); + + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![ + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 1, + }, + action: TestAction::CallBackendDelayed( + Arc::clone(&barrier_pre), + TestBackendInteraction::Set { + k: String::from("a"), + v: 4, + }, + Arc::clone(&barrier_post), + ), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 2, + }, + action: TestAction::BlockAndChangeRequest( + Arc::clone(&barrier_pre), + vec![SendableChangeRequest::set(String::from("a"), 3)], + ), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 3, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 4, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: 
TestAction::ChangeRequests(vec![]), + }, + ])); + + backend.set(String::from("a"), 1); + assert_eq!(backend.get(&String::from("a")), Some(1)); + + backend.set(String::from("a"), 2); + + barrier_post.wait(); + assert_eq!(backend.get(&String::from("a")), Some(4)); + } + + #[test] + fn test_get_from_policies_are_propagated() { + let barrier_pre = Arc::new(Barrier::new(2)); + let barrier_post = Arc::new(Barrier::new(2)); + + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![ + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 1, + }, + action: TestAction::CallBackendDelayed( + Arc::clone(&barrier_pre), + TestBackendInteraction::Get { + k: String::from("a"), + }, + Arc::clone(&barrier_post), + ), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + ])); + + backend.set(String::from("a"), 1); + barrier_pre.wait(); + barrier_post.wait(); + } + + /// Checks that dropping [`PolicyBackend`] drop the policies as well as the inner backend. + #[test] + fn test_drop() { + let marker_backend = Arc::new(()); + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::new( + Box::new(DropTester(Arc::clone(&marker_backend), ())), + time_provider, + ); + + let marker_policy = Arc::new(()); + backend.add_policy(|callback| DropTester(Arc::clone(&marker_policy), callback)); + + assert_eq!(Arc::strong_count(&marker_backend), 2); + assert_eq!(Arc::strong_count(&marker_policy), 2); + + drop(backend); + + assert_eq!(Arc::strong_count(&marker_backend), 1); + assert_eq!(Arc::strong_count(&marker_policy), 1); + } + + /// We have to ways of handling "compound" [`ChangeRequest`]s, i.e. requests that perform + /// multiple operations: + /// + /// 1. 
We could loop over the operations and inner-loop over the policies to collect reactions + /// 2. We could loop over all the policies and present each polices all operations in one go + /// + /// We've decided to chose option 1. This test ensures that by setting up a compound request + /// (reacting to `set("a", 11)`) with a compound of two operations (`set("a", 12)`, `set("a", + /// 13)`) which we call `C1` and `C2` (for "compound 1 and 2"). The two policies react to + /// these two compound operations as follows: + /// + /// | | Policy 1 | Policy 2 | + /// | -- | -------------- | -------------- | + /// | C1 | `set("a", 14)` | `set("a", 15)` | + /// | C2 | `set("a", 16)` | -- | + /// + /// For option (1) the outcome will be `"a" -> 16`, for option (2) the outcome would be `"a" -> + /// 15`. + #[test] + fn test_ordering_within_compound_requests() { + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(create_test_policy(vec![ + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 11, + }, + action: TestAction::ChangeRequests(vec![SendableChangeRequest::from_fn( + |backend| { + backend.set(String::from("a"), 12); + backend.set(String::from("a"), 13); + }, + )]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 12, + }, + action: TestAction::ChangeRequests(vec![SendableChangeRequest::set( + String::from("a"), + 14, + )]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 13, + }, + action: TestAction::ChangeRequests(vec![SendableChangeRequest::set( + String::from("a"), + 16, + )]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 14, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 15, + }, + action: TestAction::ChangeRequests(vec![]), 
+ }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 16, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + ])); + backend.add_policy(create_test_policy(vec![ + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 11, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 12, + }, + action: TestAction::ChangeRequests(vec![SendableChangeRequest::set( + String::from("a"), + 15, + )]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 13, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 14, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 15, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Set { + k: String::from("a"), + v: 16, + }, + action: TestAction::ChangeRequests(vec![]), + }, + TestStep { + condition: TestBackendInteraction::Get { + k: String::from("a"), + }, + action: TestAction::ChangeRequests(vec![]), + }, + ])); + + backend.set(String::from("a"), 11); + + assert_eq!(backend.get(&String::from("a")), Some(16)); + } + + #[derive(Debug)] + struct DropTester(Arc<()>, T) + where + T: Debug + Send + 'static; + + impl CacheBackend for DropTester + where + T: Debug + Send + 'static, + { + type K = (); + type V = (); + + fn get(&mut self, _k: &Self::K) -> Option { + unreachable!() + } + + fn set(&mut self, _k: Self::K, _v: Self::V) { + unreachable!() + } + + fn remove(&mut self, _k: &Self::K) { + unreachable!() + } + + fn is_empty(&self) -> bool { + true + } + + fn as_any(&self) -> &dyn std::any::Any 
{ + unreachable!() + } + } + + impl Subscriber for DropTester + where + T: Debug + Send + 'static, + { + type K = (); + type V = (); + } + + fn create_test_policy( + steps: Vec, + ) -> impl FnOnce(CallbackHandle) -> TestSubscriber { + |handle| TestSubscriber { + background_task: TestSubscriberBackgroundTask::NotStarted(handle), + steps: VecDeque::from(steps), + } + } + + #[derive(Debug, PartialEq)] + enum TestBackendInteraction { + Get { k: String }, + + Set { k: String, v: usize }, + + Remove { k: String }, + + Panic, + } + + impl TestBackendInteraction { + fn perform(self, handle: &mut CallbackHandle) { + match self { + Self::Get { k } => { + handle.execute_requests(vec![ChangeRequest::get(k)]); + } + Self::Set { k, v } => handle.execute_requests(vec![ChangeRequest::set(k, v)]), + Self::Remove { k } => handle.execute_requests(vec![ChangeRequest::remove(k)]), + Self::Panic => panic!("this is a test"), + } + } + } + + #[derive(Debug)] + enum TestAction { + /// Perform an illegal direct, recursive call to the backend. + CallBackendDirectly(TestBackendInteraction), + + /// Return change requests + ChangeRequests(Vec), + + /// Use callback backend but wait for a barrier in a background thread. + /// + /// This will return immediately. + CallBackendDelayed(Arc, TestBackendInteraction, Arc), + + /// Block on barrier and return afterwards. 
+ BlockAndChangeRequest(Arc, Vec), + } + + impl TestAction { + fn perform( + self, + background_task: &mut TestSubscriberBackgroundTask, + ) -> Vec> { + match self { + Self::CallBackendDirectly(interaction) => { + let handle = match background_task { + TestSubscriberBackgroundTask::NotStarted(handle) => handle, + TestSubscriberBackgroundTask::Started(_) => { + panic!("background task already started") + } + TestSubscriberBackgroundTask::Invalid => panic!("Invalid state"), + }; + + interaction.perform(handle); + unreachable!("illegal call should have failed") + } + Self::ChangeRequests(change_requests) => { + change_requests.into_iter().map(|r| r.into()).collect() + } + Self::CallBackendDelayed(barrier_pre, interaction, barrier_post) => { + let mut tmp = TestSubscriberBackgroundTask::Invalid; + std::mem::swap(&mut tmp, background_task); + let mut handle = match tmp { + TestSubscriberBackgroundTask::NotStarted(handle) => handle, + TestSubscriberBackgroundTask::Started(_) => { + panic!("background task already started") + } + TestSubscriberBackgroundTask::Invalid => panic!("Invalid state"), + }; + + let join_handle = std::thread::spawn(move || { + barrier_pre.wait(); + interaction.perform(&mut handle); + barrier_post.wait(); + }); + *background_task = TestSubscriberBackgroundTask::Started(join_handle); + + vec![] + } + Self::BlockAndChangeRequest(barrier, change_requests) => { + barrier.wait(); + change_requests.into_iter().map(|r| r.into()).collect() + } + } + } + } + + #[derive(Debug)] + struct TestStep { + condition: TestBackendInteraction, + action: TestAction, + } + + #[derive(Debug)] + enum TestSubscriberBackgroundTask { + NotStarted(CallbackHandle), + Started(JoinHandle<()>), + + /// Temporary variant for swapping. 
+ Invalid, + } + + #[derive(Debug)] + struct TestSubscriber { + background_task: TestSubscriberBackgroundTask, + steps: VecDeque, + } + + impl Drop for TestSubscriber { + fn drop(&mut self) { + // prevent SIGABRT due to double-panic + if !std::thread::panicking() { + assert!(self.steps.is_empty(), "test steps left"); + let mut tmp = TestSubscriberBackgroundTask::Invalid; + std::mem::swap(&mut tmp, &mut self.background_task); + + match tmp { + TestSubscriberBackgroundTask::NotStarted(_) => { + // nothing to check + } + TestSubscriberBackgroundTask::Started(handle) => { + // propagate panics + if let Err(e) = handle.join() { + if let Some(err) = e.downcast_ref::<&str>() { + panic!("Error in background task: {err}") + } else if let Some(err) = e.downcast_ref::() { + panic!("Error in background task: {err}") + } else { + panic!("Error in background task: ") + } + } + } + TestSubscriberBackgroundTask::Invalid => { + // that's OK during drop + } + } + } + } + } + + impl Subscriber for TestSubscriber { + type K = String; + type V = usize; + + fn get( + &mut self, + k: &Self::K, + _now: Time, + ) -> Vec> { + let step = self.steps.pop_front().expect("step left for get operation"); + + let expected_condition = TestBackendInteraction::Get { k: k.clone() }; + assert_eq!( + step.condition, expected_condition, + "Condition mismatch\n\nActual:\n{:#?}\n\nExpected:\n{:#?}", + step.condition, expected_condition, + ); + + step.action.perform(&mut self.background_task) + } + + fn set( + &mut self, + k: &Self::K, + v: &Self::V, + _now: Time, + ) -> Vec> { + let step = self.steps.pop_front().expect("step left for set operation"); + + let expected_condition = TestBackendInteraction::Set { + k: k.clone(), + v: *v, + }; + assert_eq!( + step.condition, expected_condition, + "Condition mismatch\n\nActual:\n{:#?}\n\nExpected:\n{:#?}", + step.condition, expected_condition, + ); + + step.action.perform(&mut self.background_task) + } + + fn remove( + &mut self, + k: &Self::K, + _now: Time, + ) 
-> Vec> { + let step = self + .steps + .pop_front() + .expect("step left for remove operation"); + + let expected_condition = TestBackendInteraction::Remove { k: k.clone() }; + assert_eq!( + step.condition, expected_condition, + "Condition mismatch\n\nActual:\n{:#?}\n\nExpected:\n{:#?}", + step.condition, expected_condition, + ); + + step.action.perform(&mut self.background_task) + } + } + + /// Same as [`ChangeRequestFn`] but implements `Send`. + type SendableChangeRequestFn = + Box FnOnce(&'b mut Recorder<'a, String, usize>) + Send + 'static>; + + /// Same as [`ChangeRequest`] but implements `Send`. + struct SendableChangeRequest { + fun: SendableChangeRequestFn, + } + + impl Debug for SendableChangeRequest { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SendableCacheRequest") + .finish_non_exhaustive() + } + } + + impl SendableChangeRequest { + fn from_fn(f: F) -> Self + where + F: for<'b, 'c> FnOnce(&'c mut Recorder<'b, String, usize>) + Send + 'static, + { + Self { fun: Box::new(f) } + } + + fn get(k: String) -> Self { + Self::from_fn(move |backend| { + backend.get(&k); + }) + } + + fn set(k: String, v: usize) -> Self { + Self::from_fn(move |backend| { + backend.set(k, v); + }) + } + + fn remove(k: String) -> Self { + Self::from_fn(move |backend| { + backend.remove(&k); + }) + } + } + + impl From for ChangeRequest<'static, String, usize> { + fn from(request: SendableChangeRequest) -> Self { + Self::from_fn(request.fun) + } + } +} diff --git a/cache_system/src/backend/policy/refresh.rs b/cache_system/src/backend/policy/refresh.rs new file mode 100644 index 0000000..6682352 --- /dev/null +++ b/cache_system/src/backend/policy/refresh.rs @@ -0,0 +1,1028 @@ +//! Refresh handling. 
+use std::{ + fmt::Debug, + hash::Hash, + marker::PhantomData, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, +}; + +use backoff::{Backoff, BackoffConfig}; +use futures::{future::BoxFuture, stream::FuturesUnordered, FutureExt, StreamExt}; +use iox_time::{Time, TimeProvider}; +use metric::U64Counter; +use parking_lot::Mutex; +use rand::rngs::mock::StepRng; +use tokio::{runtime::Handle, sync::Notify, task::JoinHandle}; +use tokio_util::sync::CancellationToken; + +use crate::{addressable_heap::AddressableHeap, loader::Loader}; + +use super::{CacheBackend, CallbackHandle, ChangeRequest, Subscriber}; + +/// Interface to provide refresh duration for a key-value pair. +pub trait RefreshDurationProvider: std::fmt::Debug + Send + Sync + 'static { + /// Cache key. + type K; + + /// Cached value. + type V; + + /// When should the given key-value pair be refreshed? + /// + /// Return `None` for "never". + /// + /// The function is only called once for a newly cached key-value pair. This means: + /// - There is no need in remembering the time of a given pair (e.g. you can safely always return a constant). + /// - You cannot change the timings after the data was cached. + /// + /// Refresh is set to take place AT OR AFTER the provided duration. + fn refresh_in(&self, k: &Self::K, v: &Self::V) -> Option; +} + +/// [`RefreshDurationProvider`] that never expires. 
+#[derive(Default)] +pub struct NeverRefreshProvider +where + K: 'static, + V: 'static, +{ + // phantom data that is Send and Sync, see https://stackoverflow.com/a/50201389 + _k: PhantomData K>, + _v: PhantomData V>, +} + +impl std::fmt::Debug for NeverRefreshProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NeverRefreshProvider") + .finish_non_exhaustive() + } +} + +impl RefreshDurationProvider for NeverRefreshProvider { + type K = K; + type V = V; + + fn refresh_in(&self, _k: &Self::K, _v: &Self::V) -> Option { + None + } +} + +/// [`RefreshDurationProvider`] that returns different values for `None`/`Some(...)` values. +pub struct OptionalValueRefreshDurationProvider +where + K: 'static, + V: 'static, +{ + // phantom data that is Send and Sync, see https://stackoverflow.com/a/50201389 + _k: PhantomData K>, + _v: PhantomData V>, + + backoff_cfg_none: Option, + backoff_cfg_some: Option, +} + +impl OptionalValueRefreshDurationProvider +where + K: 'static, + V: 'static, +{ + /// Create new provider with the given refresh duration for `None` and `Some(...)`. + pub fn new( + backoff_cfg_none: Option, + backoff_cfg_some: Option, + ) -> Self { + Self { + _k: PhantomData, + _v: PhantomData, + backoff_cfg_none, + backoff_cfg_some, + } + } +} + +impl std::fmt::Debug for OptionalValueRefreshDurationProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OptionalValueRefreshDurationProvider") + .field("t_none", &self.backoff_cfg_none) + .field("t_some", &self.backoff_cfg_some) + .finish_non_exhaustive() + } +} + +impl RefreshDurationProvider for OptionalValueRefreshDurationProvider { + type K = K; + type V = Option; + + fn refresh_in(&self, _k: &Self::K, v: &Self::V) -> Option { + match v { + None => self.backoff_cfg_none.clone(), + Some(_) => self.backoff_cfg_some.clone(), + } + } +} + +/// Tag for keys (incl. 
their backoff state and their running background tasks) to reason about lock gaps. +type Tag = u64; + +/// Cache policy that implements refreshing. +#[derive(Debug)] +pub struct RefreshPolicy +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + refresh_duration_provider: Arc>, + background_worker: JoinHandle<()>, + timings: Arc>>, + timings_changed: Arc, + tag_counter: AtomicU64, + rng_overwrite: Option, +} + +impl RefreshPolicy +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + /// Create new refresh policy. + #[allow(clippy::new_ret_no_self)] + pub fn new( + time_provider: Arc, + refresh_duration_provider: Arc>, + loader: Arc>, + name: &'static str, + metric_registry: &metric::Registry, + handle: &Handle, + ) -> impl FnOnce(CallbackHandle) -> Self { + let idle_notify = Arc::new(Notify::new()); + Self::new_inner( + time_provider, + refresh_duration_provider, + loader, + name, + metric_registry, + idle_notify, + handle, + None, + ) + } + + /// Create new refresh policy but allows to specify some internals for testing. + /// + /// These internals are: + /// + /// - `idle_notify`: a [`Notify`] that will be triggered when the background worker is idle. + /// - `rng_overwrite`: a static RNG that will be used for the [`backoff`]-based refresh timers instead of a true + /// thread RNG. 
+ #[allow(clippy::new_ret_no_self, clippy::too_many_arguments)] + pub(crate) fn new_inner( + time_provider: Arc, + refresh_duration_provider: Arc>, + loader: Arc>, + name: &'static str, + metric_registry: &metric::Registry, + idle_notify: Arc, + handle: &Handle, + rng_overwrite: Option, + ) -> impl FnOnce(CallbackHandle) -> Self { + let metric_refreshed = metric_registry + .register_metric::("cache_refresh", "Number of cache refresh operations.") + .recorder(&[("name", name)]); + + // clone handle for callback + let handle = handle.clone(); + + move |mut callback_handle| { + callback_handle.execute_requests(vec![ChangeRequest::ensure_empty()]); + + let timings: Arc>> = + Default::default(); + let timings_captured = Arc::clone(&timings); + let timings_changed = Arc::new(Notify::new()); + let timings_changed_captured = Arc::clone(&timings_changed); + let callback_handle = Arc::new(Mutex::new(callback_handle)); + let rng_overwrite_captured = rng_overwrite.clone(); + + let background_worker = handle.spawn(async move { + let mut refresh_tasks = FuturesUnordered::>>::new(); + + // We MUST NOT poll the empty task set because this would finish immediately. This will hot-loop + // the loop. Even worse, since `FuturesUnodered` is not hooked up into tokio's (somewhat bizarre) + // task preemtion system, tokio will poll this method here forever, essentially blocking this + // thread. 
+ refresh_tasks.push(Box::pin(futures::future::pending())); + + // flag that remembers if we can notify idle observers again + let mut can_notify_idle = true; + + loop { + // future that waits for the next refresh task to start + let fut_start_next_task: BoxFuture<'static, ()> = { + let timings = timings_captured.lock(); + match timings.peek() { + None => Box::pin(futures::future::pending()), + Some((_k, _state, t_next)) => match t_next { + TimeOrNever::Never => Box::pin(futures::future::pending()), + TimeOrNever::Time(t) => Box::pin(time_provider.sleep_until(*t)), + } + } + }; + + // future that "guards" our idle notification to prevent hot loops (essentially blocking the entire + // tokio thread forever) + let fut_idle_notify_guard: BoxFuture<'static, ()> = if can_notify_idle { + Box::pin(futures::future::ready(())) + } else { + Box::pin(futures::future::pending()) + }; + + tokio::select! { + biased; + maybe_k_and_tag = refresh_tasks.next() => { + // a refresh tasks finished + + // see if this refresh task was NOT finished + if let Some((k, tag)) = maybe_k_and_tag.flatten() { + let mut timings = timings_captured.lock(); + if let Some((mut state, t_next)) = timings.remove(&k) { + if state.tag == tag { + state.running_refresh = None; + let (state, t_next) = state.next(time_provider.now(), &rng_overwrite_captured); + timings.insert(k, state, t_next); + } else { + // wrong one (lock gap) + timings.insert(k, state, t_next); + } + } + } + + can_notify_idle = true; + } + _ = fut_start_next_task => { + // a new refresh task shall start + let mut timings = timings_captured.lock(); + + // careful with inspection of timings since there was a lock-gap, the data might have changed + if let Some((k, mut state, t_next)) = timings.pop() { + if t_next <= TimeOrNever::Time(time_provider.now()) { + assert!(state.running_refresh.is_none()); + + let (fut, ctoken) = Self::refresh(Arc::clone(&loader), Arc::clone(&callback_handle), k.clone(), state.tag, metric_refreshed.clone()); + 
state.running_refresh = Some(ctoken); + refresh_tasks.push(fut); + + timings.insert(k, state, TimeOrNever::Never); + } else { + // the entry in question is gone and we got the wrong one, put it back + timings.insert(k, state, t_next); + } + } + + can_notify_idle = true; + } + _ = timings_changed_captured.notified() => { + // timings updated + + // do NOT count this as "can not notify IDLE" because nothing really happened yet + } + _ = fut_idle_notify_guard => { + // no other jobs to do (this select is biased!), we inform the external test observer + idle_notify.notify_one(); + can_notify_idle = false; + } + } + } + }); + + Self { + refresh_duration_provider, + background_worker, + timings, + timings_changed, + tag_counter: AtomicU64::new(0), + rng_overwrite, + } + } + } + + /// Start refresh task for given key and return cancelation token for the task. + /// + /// You shall store the given token in [`RefreshState`]. + #[must_use] + fn refresh( + loader: Arc>, + callback_handle: Arc>>, + k: K, + tag: Tag, + metric_refreshed: U64Counter, + ) -> (BoxFuture<'static, Option<(K, Tag)>>, CancellationToken) { + let cancelled = CancellationToken::default(); + + let cancelled_captured = cancelled.clone(); + let fut = async move { + // some `let`-dance so that rustc does not complain that `&K` is not `Send` + let k_for_loader = k.clone(); + let v = loader.load(k_for_loader, ()).await; + + let mut callback_handle = callback_handle.lock(); + callback_handle.execute_requests(vec![ChangeRequest::from_fn(|backend| { + // Here we have the PolicyBackend implicit lock. There is no way our Subscriber can be + // active here, but we need to check if we have been canceled one last time. 
+ if cancelled_captured.is_cancelled() { + return; + } + + backend.set(k.clone(), v); + })]); + + // update metric AFTER change request + metric_refreshed.inc(1); + + // there is NO need to update our own `timings` after this refresh because this very Subscriber + // will also get a `set` notification and update its timing table accordingly + (k, tag) + }; + + let cancelled_captured = cancelled.clone(); + let fut = async move { + tokio::select! { + _ = cancelled_captured.cancelled() => None, + k = fut => Some(k), + } + } + .boxed(); + + (fut, cancelled) + } +} + +impl Drop for RefreshPolicy +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + fn drop(&mut self) { + self.background_worker.abort(); + } +} + +impl Subscriber for RefreshPolicy +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + type K = K; + type V = V; + + fn get(&mut self, k: &Self::K, now: Time) -> Vec> { + let mut timings = self.timings.lock(); + + // Does this entry exists? 
+ if let Some((mut state, t_next)) = timings.remove(k) { + // reset backoff + state.next = None; + + if state.running_refresh.is_some() { + // there is a refresh operation running, so just reset the backoff and put this back + assert_eq!(t_next, TimeOrNever::Never); + timings.insert(k.clone(), state, TimeOrNever::Never); + } else { + // refresh operation currently NOT running => schedule one + let (state, t_next) = state.next(now, &self.rng_overwrite); + timings.insert(k.clone(), state, t_next); + self.timings_changed.notify_one(); + } + } + + vec![] + } + + fn set( + &mut self, + k: &Self::K, + v: &Self::V, + now: Time, + ) -> Vec> { + let backoff_cfg = self.refresh_duration_provider.refresh_in(k, v); + + let mut timings = self.timings.lock(); + + // ignore any entries that don't require any work + if let Some(backoff_cfg) = backoff_cfg { + if let Some((mut state, time)) = timings.remove(k) { + // we know this key already + state.next = match state.next.take() { + Some(mut next) => { + next.fade_to(&backoff_cfg); + Some(next) + } + None => None, + }; + state.backoff_cfg = backoff_cfg; + + timings.insert(k.clone(), state, time); + self.timings_changed.notify_one(); + } else { + // new key + let state = + RefreshState::new(backoff_cfg, self.tag_counter.fetch_add(1, Ordering::SeqCst)); + let (state, time) = state.next(now, &self.rng_overwrite); + + timings.insert(k.clone(), state, time); + self.timings_changed.notify_one(); + } + } else { + // need to remove potentially existing entry that had some refresh set + timings.remove(k); + + // the removal drops the RefreshState which triggers a cancelation for any potentially running + // refresh operation + } + + vec![] + } + + fn remove(&mut self, k: &Self::K, _now: Time) -> Vec> { + let mut timings = self.timings.lock(); + timings.remove(k); + + // the removal automatically triggered a cancelation for any potentially running refresh operation + + vec![] + } +} + +/// Current state of an entry managed by the refresh 
policy. +#[derive(Debug)] +struct RefreshState { + /// When to refresh or expire. + backoff_cfg: BackoffConfig, + + /// Current backoff state + next: Option, + + /// Tag that links the background task to this very entry + tag: Tag, + + /// Cancellation token for a potentially running refresh operation. + /// + /// This token will be triggered on [`drop`](Drop::drop). + running_refresh: Option, +} + +impl RefreshState { + fn new(backoff_cfg: BackoffConfig, tag: Tag) -> Self { + Self { + backoff_cfg, + next: None, + tag, + running_refresh: None, + } + } + + fn next(mut self, now: Time, rng_overwrite: &Option) -> (Self, TimeOrNever) { + assert!(self.running_refresh.is_none()); + + let mut next = self.next.take().unwrap_or_else(|| { + Backoff::new_with_rng( + &self.backoff_cfg, + rng_overwrite.as_ref().map(|rng| Box::new(rng.clone()) as _), + ) + }); + let time = match next.next().and_then(|d| now.checked_add(d)) { + None => TimeOrNever::Never, + Some(time) => TimeOrNever::Time(time), + }; + let this = Self { + backoff_cfg: self.backoff_cfg.clone(), + tag: self.tag, + next: Some(next), + running_refresh: None, + }; + (this, time) + } +} + +impl Drop for RefreshState { + fn drop(&mut self) { + if let Some(token) = &self.running_refresh { + token.cancel(); + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +enum TimeOrNever { + Time(Time), + Never, +} + +pub mod test_util { + //! Testing utilities for refresh policy. + + use std::{collections::HashMap, time::Duration}; + + use super::*; + + /// Easy-to-control [`RefreshDurationProvider`]. + #[derive(Debug, Default)] + pub struct TestRefreshDurationProvider { + times: Mutex>>, + } + + impl TestRefreshDurationProvider { + /// Create new, empty provider. + pub fn new() -> Self { + Self::default() + } + + /// Specify a refresh duration for a given key-value pair. + /// + /// Existing values will be overridden. 
+ pub fn set_refresh_in(&self, k: u8, v: String, d: Option) { + self.times.lock().insert((k, v), d); + // do NOT check if there was already a value set because we allow overrides + } + } + + impl RefreshDurationProvider for TestRefreshDurationProvider { + type K = u8; + type V = String; + + fn refresh_in(&self, k: &Self::K, v: &Self::V) -> Option { + self.times + .lock() + .get(&(*k, v.clone())) + .unwrap_or_else(|| panic!("refresh time not mocked: K={k}, V={v}")) + .clone() + } + } + + /// Some extensions for [`Notify`]. + pub trait NotifyExt { + /// Wait for notification but panic after a short timeout. + fn notified_with_timeout(&self) -> BoxFuture<'_, ()>; + + /// Ensure that we are NOT notified. + fn not_notified(&self) -> BoxFuture<'_, ()>; + } + + impl NotifyExt for Notify { + fn notified_with_timeout(&self) -> BoxFuture<'_, ()> { + Box::pin(async { + tokio::time::timeout(Duration::from_secs(1), self.notified()) + .await + .expect("notified_with_timeout"); + }) + } + + fn not_notified(&self) -> BoxFuture<'_, ()> { + Box::pin(async { + tokio::time::timeout(Duration::from_millis(10), self.notified()) + .await + .unwrap_err(); + }) + } + } + + /// Generate a simple [`BackoffConfig`] for testing. + /// + /// Uses the given duration as initial backoff and a base of 2. No max backoff and deadline are set. 
+ pub fn backoff_cfg(d: Duration) -> BackoffConfig { + BackoffConfig { + init_backoff: d, + max_backoff: Duration::MAX, + base: 2.0, + deadline: None, + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + #[should_panic(expected = "refresh time not mocked: K=1, V=foo")] + fn test_provider_panic_not_mocked() { + let provider = TestRefreshDurationProvider::default(); + provider.refresh_in(&1, &String::from("foo")); + } + + #[test] + fn test_provider_mocking() { + let provider = TestRefreshDurationProvider::default(); + + let cfg1 = BackoffConfig::default(); + let cfg2 = BackoffConfig { base: 42., ..cfg1 }; + let cfg3 = BackoffConfig { + base: 1337., + ..cfg1 + }; + + provider.set_refresh_in(1, String::from("a"), None); + provider.set_refresh_in(1, String::from("b"), Some(cfg1.clone())); + provider.set_refresh_in(2, String::from("a"), Some(cfg2.clone())); + + assert_eq!(provider.refresh_in(&1, &String::from("a")), None); + assert_eq!(provider.refresh_in(&1, &String::from("b")), Some(cfg1),); + assert_eq!(provider.refresh_in(&2, &String::from("a")), Some(cfg2),); + + // replace + provider.set_refresh_in(1, String::from("a"), Some(cfg3.clone())); + assert_eq!(provider.refresh_in(&1, &String::from("a")), Some(cfg3),); + } + } +} + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, time::Duration}; + + use iox_time::MockProvider; + use metric::{Observation, RawReporter}; + use rand::rngs::mock::StepRng; + + use crate::{ + backend::{ + policy::{ + refresh::test_util::{backoff_cfg, NotifyExt}, + PolicyBackend, + }, + CacheBackend, + }, + loader::test_util::TestLoader, + }; + + use super::{test_util::TestRefreshDurationProvider, *}; + + #[test] + fn test_time_or_never_ord() { + assert!(TimeOrNever::Never == TimeOrNever::Never); + assert!( + TimeOrNever::Time(Time::from_timestamp_millis(1).unwrap()) + == TimeOrNever::Time(Time::from_timestamp_millis(1).unwrap()) + ); + assert!( + TimeOrNever::Time(Time::from_timestamp_millis(1).unwrap()) + < 
TimeOrNever::Time(Time::from_timestamp_millis(2).unwrap()) + ); + assert!(TimeOrNever::Time(Time::from_timestamp_millis(1).unwrap()) < TimeOrNever::Never); + } + + #[test] + fn test_never_refresh_provider() { + let provider = NeverRefreshProvider::::default(); + assert_eq!(provider.refresh_in(&1, &2), None); + } + + #[test] + fn test_optional_value_ttl_provider() { + let t_none = Some(BackoffConfig { + base: 1., + ..Default::default() + }); + let t_some = Some(BackoffConfig { + base: 2., + ..Default::default() + }); + let provider = + OptionalValueRefreshDurationProvider::::new(t_none.clone(), t_some.clone()); + assert_eq!(provider.refresh_in(&1, &None), t_none); + assert_eq!(provider.refresh_in(&1, &Some(2)), t_some); + } + + #[tokio::test] + #[should_panic(expected = "inner backend is not empty")] + async fn test_panic_inner_not_empty() { + let refresh_duration_provider = Arc::new(TestRefreshDurationProvider::new()); + let metric_registry = metric::Registry::new(); + + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let loader = Arc::new(TestLoader::default()); + let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + let policy_constructor = RefreshPolicy::new( + time_provider, + refresh_duration_provider, + loader, + "my_cache", + &metric_registry, + &Handle::current(), + ); + backend.add_policy(|mut handle| { + handle.execute_requests(vec![ChangeRequest::set(1, String::from("foo"))]); + policy_constructor(handle) + }); + } + + #[tokio::test] + async fn test_duration_overflow() { + let refresh_duration_provider = Arc::new(TestRefreshDurationProvider::new()); + refresh_duration_provider.set_refresh_in( + 1, + String::from("a"), + Some(BackoffConfig { + init_backoff: Duration::MAX, + ..Default::default() + }), + ); + + let metric_registry = metric::Registry::new(); + let time_provider = Arc::new(MockProvider::new(Time::MAX - Duration::from_secs(1))); + let loader = Arc::new(TestLoader::default()); + let mut backend = 
PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend.add_policy(RefreshPolicy::new( + Arc::clone(&time_provider) as _, + refresh_duration_provider, + loader, + "my_cache", + &metric_registry, + &Handle::current(), + )); + + backend.set(1, String::from("a")); + + time_provider.inc(Duration::from_secs(1)); + assert_eq!(backend.get(&1), Some(String::from("a"))); + assert_eq!(get_refresh_metric(&metric_registry), 0); + } + + #[tokio::test] + async fn test_refresh() { + let TestState { + mut backend, + refresh_duration_provider, + time_provider, + loader, + metric_registry, + notify_idle, + .. + } = TestState::new(); + + loader.mock_next(1, String::from("foo")); + loader.mock_next(1, String::from("bar")); + + refresh_duration_provider.set_refresh_in( + 1, + String::from("a"), + Some(backoff_cfg(Duration::from_secs(1))), + ); + refresh_duration_provider.set_refresh_in( + 1, + String::from("foo"), + Some(backoff_cfg(Duration::from_secs(1))), + ); + refresh_duration_provider.set_refresh_in(1, String::from("bar"), None); + + // start backoff cycle + backend.set(1, String::from("a")); + + // initial notify by the background loop + notify_idle.notified_with_timeout().await; + + // still the same key + assert_eq!(get_inner(&mut backend, 1), Some(String::from("a"))); + assert_eq!(get_refresh_metric(&metric_registry), 0); + + // refresh starts by background timer + time_provider.inc(Duration::from_secs(1)); + notify_idle.notified_with_timeout().await; + assert_eq!(get_refresh_metric(&metric_registry), 1); + assert_eq!(get_inner(&mut backend, 1), Some(String::from("foo"))); + + // nothing to refresh yet + notify_idle.not_notified().await; + assert_eq!(get_refresh_metric(&metric_registry), 1); + assert_eq!(get_inner(&mut backend, 1), Some(String::from("foo"))); + + // just bumping the refresh by the old refresh timer won't do anything (we need 2 seconds this time due to the + // base factor) + time_provider.inc(Duration::from_secs(1)); + 
notify_idle.not_notified().await; + assert_eq!(get_refresh_metric(&metric_registry), 1); + assert_eq!(get_inner(&mut backend, 1), Some(String::from("foo"))); + + // try a 2nd update + time_provider.inc(Duration::from_secs(1)); + notify_idle.notified_with_timeout().await; + assert_eq!(get_refresh_metric(&metric_registry), 2); + assert_eq!(get_inner(&mut backend, 1), Some(String::from("bar"))); + } + + #[tokio::test] + async fn test_do_not_start_refresh_while_one_is_running() { + let TestState { + mut backend, + refresh_duration_provider, + time_provider, + loader, + notify_idle, + .. + } = TestState::new(); + + let barrier = loader.block_next(1, String::from("foo")); + refresh_duration_provider.set_refresh_in( + 1, + String::from("a"), + Some(backoff_cfg(Duration::from_secs(1))), + ); + refresh_duration_provider.set_refresh_in(1, String::from("foo"), None); + backend.set(1, String::from("a")); + + time_provider.inc(Duration::from_secs(1)); + notify_idle.notified_with_timeout().await; + + // if this would start another refresh then the loader would panic because we've only mocked a single request + time_provider.inc(Duration::from_secs(100)); + notify_idle.not_notified().await; + + barrier.wait().await; + notify_idle.notified_with_timeout().await; + assert_eq!(backend.get(&1), Some(String::from("foo"))); + } + + #[tokio::test] + async fn test_refresh_does_not_override_new_entries() { + let TestState { + mut backend, + refresh_duration_provider, + time_provider, + loader, + notify_idle, + .. 
+ } = TestState::new(); + + let barrier = loader.block_next(1, String::from("foo")); + refresh_duration_provider.set_refresh_in( + 1, + String::from("a"), + Some(backoff_cfg(Duration::from_secs(1))), + ); + refresh_duration_provider.set_refresh_in(1, String::from("b"), None); + backend.set(1, String::from("a")); + + // perform refresh + time_provider.inc(Duration::from_secs(1)); + notify_idle.notified_with_timeout().await; + + backend.set(1, String::from("b")); + barrier.wait().await; + notify_idle.notified_with_timeout().await; + assert_eq!(backend.get(&1), Some(String::from("b"))); + } + + #[tokio::test] + async fn test_remove_cancels_loader() { + let TestState { + mut backend, + refresh_duration_provider, + time_provider, + loader, + notify_idle, + .. + } = TestState::new(); + + let barrier = loader.block_next(1, String::from("foo")); + refresh_duration_provider.set_refresh_in( + 1, + String::from("a"), + Some(backoff_cfg(Duration::from_secs(1))), + ); + backend.set(1, String::from("a")); + + // perform refresh + time_provider.inc(Duration::from_secs(1)); + notify_idle.notified_with_timeout().await; + + assert_eq!(Arc::strong_count(&barrier), 2); + backend.remove(&1); + notify_idle.notified_with_timeout().await; + assert_eq!(Arc::strong_count(&barrier), 1); + } + + #[tokio::test] + async fn test_override_with_no_refresh() { + let TestState { + mut backend, + refresh_duration_provider, + time_provider, + loader, + notify_idle, + .. 
+ } = TestState::new(); + + let barrier = loader.block_next(1, String::from("foo")); + refresh_duration_provider.set_refresh_in( + 1, + String::from("a"), + Some(backoff_cfg(Duration::from_secs(1))), + ); + refresh_duration_provider.set_refresh_in(1, String::from("b"), None); + backend.set(1, String::from("a")); + + // perform refresh + time_provider.inc(Duration::from_secs(1)); + notify_idle.notified_with_timeout().await; + + backend.set(1, String::from("b")); + barrier.wait().await; + + // no refresh + time_provider.inc(Duration::from_secs(1)); + notify_idle.notified_with_timeout().await; + assert_eq!(backend.get(&1), Some(String::from("b"))); + } + + #[tokio::test] + async fn test_generic_backend() { + use crate::backend::test_util::test_generic; + + test_generic(|| { + let refresh_duration_provider = Arc::new(NeverRefreshProvider::default()); + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let metric_registry = metric::Registry::new(); + let loader = Arc::new(TestLoader::default()); + let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + + backend.add_policy(RefreshPolicy::new( + time_provider, + Arc::clone(&refresh_duration_provider) as _, + loader, + "my_cache", + &metric_registry, + &Handle::current(), + )); + backend + }); + } + + struct TestState { + backend: PolicyBackend, + metric_registry: metric::Registry, + refresh_duration_provider: Arc, + time_provider: Arc, + loader: Arc>, + notify_idle: Arc, + } + + impl TestState { + fn new() -> Self { + let refresh_duration_provider = Arc::new(TestRefreshDurationProvider::new()); + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let metric_registry = metric::Registry::new(); + let loader = Arc::new(TestLoader::default()); + let notify_idle = Arc::new(Notify::new()); + + // set up "RNG" that always generates the maximum, so we can test things easier + let rng_overwrite = StepRng::new(u64::MAX, 0); + + let mut backend = 
PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend.add_policy(RefreshPolicy::new_inner( + Arc::clone(&time_provider) as _, + Arc::clone(&refresh_duration_provider) as _, + Arc::clone(&loader) as _, + "my_cache", + &metric_registry, + Arc::clone(¬ify_idle), + &Handle::current(), + Some(rng_overwrite), + )); + + Self { + backend, + metric_registry, + refresh_duration_provider, + time_provider, + loader, + notify_idle, + } + } + } + + fn get_inner(backend: &mut PolicyBackend, k: u8) -> Option { + let inner_backend = backend.inner_ref(); + let inner_backend = inner_backend + .as_any() + .downcast_ref::>() + .unwrap(); + inner_backend.get(&k).cloned() + } + + fn get_refresh_metric(metric_registry: &metric::Registry) -> u64 { + let mut reporter = RawReporter::default(); + metric_registry.report(&mut reporter); + let observation = reporter + .metric("cache_refresh") + .unwrap() + .observation(&[("name", "my_cache")]) + .unwrap(); + + if let Observation::U64Counter(c) = observation { + *c + } else { + panic!("Wrong observation type") + } + } +} diff --git a/cache_system/src/backend/policy/remove_if.rs b/cache_system/src/backend/policy/remove_if.rs new file mode 100644 index 0000000..57abf45 --- /dev/null +++ b/cache_system/src/backend/policy/remove_if.rs @@ -0,0 +1,288 @@ +//! Backend that supports custom removal / expiry of keys +use metric::U64Counter; +use parking_lot::Mutex; +use std::{fmt::Debug, hash::Hash, marker::PhantomData, sync::Arc}; + +use crate::{ + backend::policy::{CacheBackend, CallbackHandle, ChangeRequest, Subscriber}, + cache::{Cache, CacheGetStatus}, +}; + +/// Allows explicitly removing entries from the cache. 
+#[derive(Debug, Clone)] +pub struct RemoveIfPolicy +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + // the policy itself doesn't do anything, the handles will do all the work + _phantom: PhantomData<(K, V)>, +} + +impl RemoveIfPolicy +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + /// Create new policy. + /// + /// This returns the policy constructor which shall be pass to + /// [`PolicyBackend::add_policy`] and handle that can be used to remove entries. + /// + /// Note that as long as the policy constructor is NOT passed to [`PolicyBackend::add_policy`], the operations on + /// the handle are essentially no-ops (i.e. they will not remove anything). + /// + /// [`PolicyBackend::add_policy`]: super::PolicyBackend::add_policy + pub fn create_constructor_and_handle( + name: &'static str, + metric_registry: &metric::Registry, + ) -> ( + impl FnOnce(CallbackHandle) -> Self, + RemoveIfHandle, + ) { + let metric_removed_by_predicate = metric_registry + .register_metric::( + "cache_removed_by_custom_condition", + "Number of entries removed from a cache via a custom condition", + ) + .recorder(&[("name", name)]); + + let handle = RemoveIfHandle { + callback_handle: Arc::new(Mutex::new(None)), + metric_removed_by_predicate, + }; + let handle_captured = handle.clone(); + + let policy_constructor = move |callback_handle| { + *handle_captured.callback_handle.lock() = Some(callback_handle); + Self { + _phantom: PhantomData, + } + }; + + (policy_constructor, handle) + } +} + +impl Subscriber for RemoveIfPolicy +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + type K = K; + type V = V; +} + +/// Handle created by [`RemoveIfPolicy`] that can be used to evict data from caches. +/// +/// The handle can be cloned freely. All clones will refer to the same underlying backend. 
+#[derive(Debug, Clone)] +pub struct RemoveIfHandle +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + callback_handle: Arc>>>, + metric_removed_by_predicate: U64Counter, +} + +impl RemoveIfHandle +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + /// "remove" a key (aka remove it from the shared backend) if the + /// specified predicate is true. If the key is removed return + /// true, otherwise return false + /// + /// Note that the predicate function is called while the lock is + /// held (and thus the inner backend can't be concurrently accessed + pub fn remove_if

(&self, k: &K, predicate: P) -> bool + where + P: FnOnce(V) -> bool, + { + let mut guard = self.callback_handle.lock(); + let handle = match guard.as_mut() { + Some(handle) => handle, + None => return false, + }; + + let metric_removed_by_predicate = self.metric_removed_by_predicate.clone(); + let mut removed = false; + let removed_captured = &mut removed; + let k = k.clone(); + handle.execute_requests(vec![ChangeRequest::from_fn(move |backend| { + if let Some(v) = backend.get_untracked(&k) { + if predicate(v) { + metric_removed_by_predicate.inc(1); + backend.remove(&k); + *removed_captured = true; + } + } + })]); + + removed + } + + /// Performs [`remove_if`](Self::remove_if) and [`GET`](Cache::get) in one go. + /// + /// Ensures that these two actions interact correctly. + /// + /// # Forward process + /// This function only works if cache values evolve in one direction. This is that the predicate can only flip from + /// `true` to `false` over time (i.e. it detects an outdated value and then an up-to-date value), NOT the other way + /// around (i.e. data cannot get outdated under the same predicate). 
+ pub async fn remove_if_and_get_with_status( + &self, + cache: &C, + k: K, + predicate: P, + extra: GetExtra, + ) -> (V, CacheGetStatus) + where + P: Fn(V) -> bool + Send, + C: Cache, + GetExtra: Clone + Send, + { + let mut removed = self.remove_if(&k, &predicate); + + loop { + // avoid some `Sync` bounds + let k_for_get = k.clone(); + let extra_for_get = extra.clone(); + let (v, status) = cache.get_with_status(k_for_get, extra_for_get).await; + + match status { + CacheGetStatus::Hit => { + // key existed and no other process loaded it => safe to use + return (v, status); + } + CacheGetStatus::Miss => { + // key didn't exist and we loaded it => safe to use + return (v, status); + } + CacheGetStatus::MissAlreadyLoading => { + if removed { + // key was outdated but there was some loading in process, this may have overlapped with our check + // so our check might have been incomplete => need to re-check + removed = self.remove_if(&k, &predicate); + if removed { + // removed again, so cannot use our result + continue; + } else { + // didn't remove => safe to use + return (v, status); + } + } else { + // there was a load action in process but the key was already up-to-date, so it's OK to use the new + // data as well (forward process) + return (v, status); + } + } + } + } + } + + /// Same as [`remove_if_and_get_with_status`](Self::remove_if_and_get_with_status) but without the status. 
+ pub async fn remove_if_and_get( + &self, + cache: &C, + k: K, + predicate: P, + extra: GetExtra, + ) -> V + where + P: Fn(V) -> bool + Send, + C: Cache, + GetExtra: Clone + Send, + { + self.remove_if_and_get_with_status(cache, k, predicate, extra) + .await + .0 + } +} + +#[cfg(test)] +mod tests { + use iox_time::{MockProvider, Time}; + use metric::{Observation, RawReporter}; + + use crate::backend::{policy::PolicyBackend, CacheBackend}; + + use super::*; + + #[test] + fn test_generic_backend() { + use crate::backend::test_util::test_generic; + + test_generic(|| { + let metric_registry = metric::Registry::new(); + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + let (policy_constructor, _handle) = + RemoveIfPolicy::create_constructor_and_handle("my_cache", &metric_registry); + backend.add_policy(policy_constructor); + backend + }); + } + + #[test] + fn test_remove_if() { + let metric_registry = metric::Registry::new(); + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend: PolicyBackend = PolicyBackend::hashmap_backed(time_provider); + let (policy_constructor, handle) = + RemoveIfPolicy::create_constructor_and_handle("my_cache", &metric_registry); + backend.add_policy(policy_constructor); + backend.set(1, "foo".into()); + backend.set(2, "bar".into()); + + assert_eq!(get_removed_metric(&metric_registry), 0); + + assert!(!handle.remove_if(&1, |v| v == "zzz")); + assert_eq!(backend.get(&1), Some("foo".into())); + assert_eq!(backend.get(&2), Some("bar".into())); + assert_eq!(get_removed_metric(&metric_registry), 0); + + assert!(handle.remove_if(&1, |v| v == "foo")); + assert_eq!(backend.get(&1), None); + assert_eq!(backend.get(&2), Some("bar".into())); + assert_eq!(get_removed_metric(&metric_registry), 1); + + assert!(!handle.remove_if(&1, |v| v == "bar")); + assert_eq!(backend.get(&1), None); + assert_eq!(backend.get(&2), Some("bar".into())); + 
assert_eq!(get_removed_metric(&metric_registry), 1); + } + + #[test] + fn test_not_linked() { + let metric_registry = metric::Registry::new(); + let (_policy_constructor, handle) = + RemoveIfPolicy::::create_constructor_and_handle( + "my_cache", + &metric_registry, + ); + + assert_eq!(get_removed_metric(&metric_registry), 0); + + assert!(!handle.remove_if(&1, |v| v == "zzz")); + assert_eq!(get_removed_metric(&metric_registry), 0); + } + + fn get_removed_metric(metric_registry: &metric::Registry) -> u64 { + let mut reporter = RawReporter::default(); + metric_registry.report(&mut reporter); + let observation = reporter + .metric("cache_removed_by_custom_condition") + .unwrap() + .observation(&[("name", "my_cache")]) + .unwrap(); + + if let Observation::U64Counter(c) = observation { + *c + } else { + panic!("Wrong observation type") + } + } +} diff --git a/cache_system/src/backend/policy/ttl.rs b/cache_system/src/backend/policy/ttl.rs new file mode 100644 index 0000000..fee9e62 --- /dev/null +++ b/cache_system/src/backend/policy/ttl.rs @@ -0,0 +1,755 @@ +//! Time-to-live handling. +use std::{fmt::Debug, hash::Hash, marker::PhantomData, sync::Arc, time::Duration}; + +use iox_time::Time; +use metric::U64Counter; + +use crate::addressable_heap::AddressableHeap; + +use super::{CallbackHandle, ChangeRequest, Subscriber}; + +/// Interface to provide TTL (time to live) data for a key-value pair. +pub trait TtlProvider: std::fmt::Debug + Send + Sync + 'static { + /// Cache key. + type K; + + /// Cached value. + type V; + + /// When should the given key-value pair expire? + /// + /// Return `None` for "never". + /// + /// The function is only called once for a newly cached key-value pair. This means: + /// - There is no need in remembering the time of a given pair (e.g. you can safely always return a constant). + /// - You cannot change the TTL after the data was cached. + /// + /// Expiration is set to take place AT OR AFTER the provided duration. 
+ fn expires_in(&self, k: &Self::K, v: &Self::V) -> Option; +} + +/// [`TtlProvider`] that never expires. +#[derive(Default)] +pub struct NeverTtlProvider +where + K: 'static, + V: 'static, +{ + // phantom data that is Send and Sync, see https://stackoverflow.com/a/50201389 + _k: PhantomData K>, + _v: PhantomData V>, +} + +impl std::fmt::Debug for NeverTtlProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NeverTtlProvider").finish_non_exhaustive() + } +} + +impl TtlProvider for NeverTtlProvider { + type K = K; + type V = V; + + fn expires_in(&self, _k: &Self::K, _v: &Self::V) -> Option { + None + } +} + +/// [`TtlProvider`] that returns a constant value. +pub struct ConstantValueTtlProvider +where + K: 'static, + V: 'static, +{ + // phantom data that is Send and Sync, see https://stackoverflow.com/a/50201389 + _k: PhantomData K>, + _v: PhantomData V>, + + ttl: Option, +} + +impl ConstantValueTtlProvider +where + K: 'static, + V: 'static, +{ + /// Create new provider with the given TTL value. + pub fn new(ttl: Option) -> Self { + Self { + _k: PhantomData, + _v: PhantomData, + ttl, + } + } +} + +impl std::fmt::Debug for ConstantValueTtlProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ConstantValueTtlProvider") + .field("ttl", &self.ttl) + .finish_non_exhaustive() + } +} + +impl TtlProvider for ConstantValueTtlProvider { + type K = K; + type V = V; + + fn expires_in(&self, _k: &Self::K, _v: &Self::V) -> Option { + self.ttl + } +} + +/// [`TtlProvider`] that returns different values for `None`/`Some(...)` values. 
+pub struct OptionalValueTtlProvider +where + K: 'static, + V: 'static, +{ + // phantom data that is Send and Sync, see https://stackoverflow.com/a/50201389 + _k: PhantomData K>, + _v: PhantomData V>, + + ttl_none: Option, + ttl_some: Option, +} + +impl OptionalValueTtlProvider +where + K: 'static, + V: 'static, +{ + /// Create new provider with the given TTL values for `None` and `Some(...)`. + pub fn new(ttl_none: Option, ttl_some: Option) -> Self { + Self { + _k: PhantomData, + _v: PhantomData, + ttl_none, + ttl_some, + } + } +} + +impl std::fmt::Debug for OptionalValueTtlProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OptionalValueTtlProvider") + .field("ttl_none", &self.ttl_none) + .field("ttl_some", &self.ttl_some) + .finish_non_exhaustive() + } +} + +impl TtlProvider for OptionalValueTtlProvider { + type K = K; + type V = Option; + + fn expires_in(&self, _k: &Self::K, v: &Self::V) -> Option { + match v { + None => self.ttl_none, + Some(_) => self.ttl_some, + } + } +} + +/// Cache policy that implements Time To Life. +/// +/// # Cache Eviction +/// Every method ([`get`](Subscriber::get), [`set`](Subscriber::set), [`remove`](Subscriber::remove)) causes the +/// cache to check for expired keys. This may lead to certain delays, esp. when dropping the contained values takes a +/// long time. +#[derive(Debug)] +pub struct TtlPolicy +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + ttl_provider: Arc>, + expiration: AddressableHeap, + metric_expired: U64Counter, +} + +impl TtlPolicy +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + /// Create new TTL policy. 
+ pub fn new( + ttl_provider: Arc>, + name: &'static str, + metric_registry: &metric::Registry, + ) -> impl FnOnce(CallbackHandle) -> Self { + let metric_expired = metric_registry + .register_metric::( + "cache_ttl_expired", + "Number of entries that expired via TTL.", + ) + .recorder(&[("name", name)]); + + |mut callback_handle| { + callback_handle.execute_requests(vec![ChangeRequest::ensure_empty()]); + + Self { + ttl_provider, + expiration: Default::default(), + metric_expired, + } + } + } + + fn evict_expired(&mut self, now: Time) -> Vec> { + let mut requests = vec![]; + + while self + .expiration + .peek() + .map(|(_k, _, t)| *t <= now) + .unwrap_or_default() + { + let (k, _, _t) = self.expiration.pop().unwrap(); + self.metric_expired.inc(1); + requests.push(ChangeRequest::remove(k)); + } + + requests + } +} + +impl Subscriber for TtlPolicy +where + K: Clone + Eq + Debug + Hash + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + type K = K; + type V = V; + + fn get(&mut self, _k: &Self::K, now: Time) -> Vec> { + self.evict_expired(now) + } + + fn set( + &mut self, + k: &Self::K, + v: &Self::V, + now: Time, + ) -> Vec> { + let mut requests = self.evict_expired(now); + + if let Some(ttl) = self.ttl_provider.expires_in(k, v) { + if ttl.is_zero() { + requests.push(ChangeRequest::remove(k.clone())); + } + + match now.checked_add(ttl) { + Some(t) => { + self.expiration.insert(k.clone(), (), t); + } + None => { + // Still need to ensure that any current expiration is disabled + self.expiration.remove(k); + } + } + } else { + // Still need to ensure that any current expiration is disabled + self.expiration.remove(k); + }; + + requests + } + + fn remove(&mut self, k: &Self::K, now: Time) -> Vec> { + self.expiration.remove(k); + self.evict_expired(now) + } +} + +pub mod test_util { + //! Test utils for TTL policy. + use std::collections::HashMap; + + use parking_lot::Mutex; + + use super::*; + + /// [`TtlProvider`] for testing. 
+ #[derive(Debug, Default)] + pub struct TestTtlProvider { + expires_in: Mutex>>, + } + + impl TestTtlProvider { + /// Create new, empty provider. + pub fn new() -> Self { + Self::default() + } + + /// Set TTL time for given key-value pair. + pub fn set_expires_in(&self, k: u8, v: String, d: Option) { + self.expires_in.lock().insert((k, v), d); + } + } + + impl TtlProvider for TestTtlProvider { + type K = u8; + type V = String; + + fn expires_in(&self, k: &Self::K, v: &Self::V) -> Option { + *self + .expires_in + .lock() + .get(&(*k, v.clone())) + .expect("expires_in value not mocked") + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + #[should_panic(expected = "expires_in value not mocked")] + fn test_panic_value_not_mocked() { + TestTtlProvider::new().expires_in(&1, &String::from("foo")); + } + + #[test] + fn test_mocking() { + let provider = TestTtlProvider::default(); + + provider.set_expires_in(1, String::from("a"), None); + provider.set_expires_in(1, String::from("b"), Some(Duration::from_secs(1))); + provider.set_expires_in(2, String::from("a"), Some(Duration::from_secs(2))); + + assert_eq!(provider.expires_in(&1, &String::from("a")), None,); + assert_eq!( + provider.expires_in(&1, &String::from("b")), + Some(Duration::from_secs(1)), + ); + assert_eq!( + provider.expires_in(&2, &String::from("a")), + Some(Duration::from_secs(2)), + ); + + // replace + provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(3))); + assert_eq!( + provider.expires_in(&1, &String::from("a")), + Some(Duration::from_secs(3)), + ); + } + } +} + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, time::Duration}; + + use iox_time::MockProvider; + use metric::{Observation, RawReporter}; + + use crate::backend::{policy::PolicyBackend, CacheBackend}; + + use super::{test_util::TestTtlProvider, *}; + + #[test] + fn test_never_ttl_provider() { + let provider = NeverTtlProvider::::default(); + assert_eq!(provider.expires_in(&1, &2), None); + } + + 
#[test] + fn test_constant_value_ttl_provider() { + let ttl = Some(Duration::from_secs(1)); + let provider = ConstantValueTtlProvider::::new(ttl); + assert_eq!(provider.expires_in(&1, &2), ttl); + } + + #[test] + fn test_optional_value_ttl_provider() { + let ttl_none = Some(Duration::from_secs(1)); + let ttl_some = Some(Duration::from_secs(2)); + let provider = OptionalValueTtlProvider::::new(ttl_none, ttl_some); + assert_eq!(provider.expires_in(&1, &None), ttl_none); + assert_eq!(provider.expires_in(&1, &Some(2)), ttl_some); + } + + #[test] + #[should_panic(expected = "inner backend is not empty")] + fn test_panic_inner_not_empty() { + let ttl_provider = Arc::new(TestTtlProvider::new()); + let metric_registry = metric::Registry::new(); + + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let mut backend: PolicyBackend = PolicyBackend::hashmap_backed(time_provider); + let policy_constructor = + TtlPolicy::new(Arc::clone(&ttl_provider) as _, "my_cache", &metric_registry); + backend.add_policy(|mut handle| { + handle.execute_requests(vec![ChangeRequest::set(1, String::from("foo"))]); + policy_constructor(handle) + }); + } + + #[test] + fn test_expires_single() { + let TestState { + mut backend, + metric_registry, + ttl_provider, + time_provider, + } = TestState::new(); + + assert_eq!(get_expired_metric(&metric_registry), 0); + + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1))); + backend.set(1, String::from("a")); + assert_eq!(backend.get(&1), Some(String::from("a"))); + + assert_eq!(get_expired_metric(&metric_registry), 0); + + time_provider.inc(Duration::from_secs(1)); + assert_eq!(backend.get(&1), None); + + assert_eq!(get_expired_metric(&metric_registry), 1); + } + + #[test] + fn test_overflow_expire() { + let ttl_provider = Arc::new(TestTtlProvider::new()); + let metric_registry = metric::Registry::new(); + + // init time provider at MAX! 
+ let time_provider = Arc::new(MockProvider::new(Time::MAX)); + let mut backend: PolicyBackend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(TtlPolicy::new( + Arc::clone(&ttl_provider) as _, + "my_cache", + &metric_registry, + )); + + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::MAX)); + backend.set(1, String::from("a")); + assert_eq!(backend.get(&1), Some(String::from("a"))); + } + + #[test] + fn test_never_expire() { + let TestState { + mut backend, + ttl_provider, + time_provider, + .. + } = TestState::new(); + + ttl_provider.set_expires_in(1, String::from("a"), None); + backend.set(1, String::from("a")); + assert_eq!(backend.get(&1), Some(String::from("a"))); + + time_provider.inc(Duration::from_secs(1)); + assert_eq!(backend.get(&1), Some(String::from("a"))); + } + + #[test] + fn test_expiration_uses_key_and_value() { + let TestState { + mut backend, + ttl_provider, + time_provider, + .. + } = TestState::new(); + + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1))); + ttl_provider.set_expires_in(1, String::from("b"), Some(Duration::from_secs(4))); + ttl_provider.set_expires_in(2, String::from("a"), Some(Duration::from_secs(2))); + backend.set(1, String::from("b")); + + time_provider.inc(Duration::from_secs(3)); + assert_eq!(backend.get(&1), Some(String::from("b"))); + } + + #[test] + fn test_override_with_different_expiration() { + let TestState { + mut backend, + ttl_provider, + time_provider, + .. 
+ } = TestState::new(); + + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1))); + backend.set(1, String::from("a")); + assert_eq!(backend.get(&1), Some(String::from("a"))); + + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(3))); + backend.set(1, String::from("a")); + + time_provider.inc(Duration::from_secs(2)); + assert_eq!(backend.get(&1), Some(String::from("a"))); + } + + #[test] + fn test_override_with_no_expiration() { + let TestState { + mut backend, + ttl_provider, + time_provider, + .. + } = TestState::new(); + + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1))); + backend.set(1, String::from("a")); + assert_eq!(backend.get(&1), Some(String::from("a"))); + + ttl_provider.set_expires_in(1, String::from("a"), None); + backend.set(1, String::from("a")); + + time_provider.inc(Duration::from_secs(2)); + assert_eq!(backend.get(&1), Some(String::from("a"))); + } + + #[test] + fn test_override_with_some_expiration() { + let TestState { + mut backend, + ttl_provider, + time_provider, + .. + } = TestState::new(); + + ttl_provider.set_expires_in(1, String::from("a"), None); + backend.set(1, String::from("a")); + assert_eq!(backend.get(&1), Some(String::from("a"))); + + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1))); + backend.set(1, String::from("a")); + + time_provider.inc(Duration::from_secs(2)); + assert_eq!(backend.get(&1), None); + } + + #[test] + fn test_override_with_overflow() { + let ttl_provider = Arc::new(TestTtlProvider::new()); + let metric_registry = metric::Registry::new(); + + // init time provider at nearly MAX! 
+ let time_provider = Arc::new(MockProvider::new(Time::MAX - Duration::from_secs(2))); + let mut backend: PolicyBackend = + PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend.add_policy(TtlPolicy::new( + Arc::clone(&ttl_provider) as _, + "my_cache", + &metric_registry, + )); + + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1))); + backend.set(1, String::from("a")); + assert_eq!(backend.get(&1), Some(String::from("a"))); + + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(u64::MAX))); + backend.set(1, String::from("a")); + + time_provider.inc(Duration::from_secs(2)); + assert_eq!(backend.get(&1), Some(String::from("a"))); + } + + #[test] + fn test_readd_with_different_expiration() { + let TestState { + mut backend, + ttl_provider, + time_provider, + .. + } = TestState::new(); + + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1))); + backend.set(1, String::from("a")); + assert_eq!(backend.get(&1), Some(String::from("a"))); + + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(3))); + backend.remove(&1); + backend.set(1, String::from("a")); + + time_provider.inc(Duration::from_secs(2)); + assert_eq!(backend.get(&1), Some(String::from("a"))); + } + + #[test] + fn test_readd_with_no_expiration() { + let TestState { + mut backend, + ttl_provider, + time_provider, + .. + } = TestState::new(); + + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1))); + backend.set(1, String::from("a")); + assert_eq!(backend.get(&1), Some(String::from("a"))); + + ttl_provider.set_expires_in(1, String::from("a"), None); + backend.remove(&1); + backend.set(1, String::from("a")); + + time_provider.inc(Duration::from_secs(2)); + assert_eq!(backend.get(&1), Some(String::from("a"))); + } + + #[test] + fn test_update_cleans_multiple_keys() { + let TestState { + mut backend, + ttl_provider, + time_provider, + .. 
+ } = TestState::new(); + + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1))); + ttl_provider.set_expires_in(2, String::from("b"), Some(Duration::from_secs(2))); + ttl_provider.set_expires_in(3, String::from("c"), Some(Duration::from_secs(2))); + ttl_provider.set_expires_in(4, String::from("d"), Some(Duration::from_secs(3))); + backend.set(1, String::from("a")); + backend.set(2, String::from("b")); + backend.set(3, String::from("c")); + backend.set(4, String::from("d")); + assert_eq!(backend.get(&1), Some(String::from("a"))); + assert_eq!(backend.get(&2), Some(String::from("b"))); + assert_eq!(backend.get(&3), Some(String::from("c"))); + assert_eq!(backend.get(&4), Some(String::from("d"))); + + time_provider.inc(Duration::from_secs(2)); + assert_eq!(backend.get(&1), None); + + { + let inner_ref = backend.inner_ref(); + let inner_backend = inner_ref + .as_any() + .downcast_ref::>() + .unwrap(); + assert!(!inner_backend.contains_key(&1)); + assert!(!inner_backend.contains_key(&2)); + assert!(!inner_backend.contains_key(&3)); + assert!(inner_backend.contains_key(&4)); + } + + assert_eq!(backend.get(&2), None); + assert_eq!(backend.get(&3), None); + assert_eq!(backend.get(&4), Some(String::from("d"))); + } + + #[test] + fn test_remove_expired_key() { + let TestState { + mut backend, + ttl_provider, + time_provider, + .. + } = TestState::new(); + + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1))); + backend.set(1, String::from("a")); + assert_eq!(backend.get(&1), Some(String::from("a"))); + + time_provider.inc(Duration::from_secs(1)); + backend.remove(&1); + assert_eq!(backend.get(&1), None); + } + + #[test] + fn test_expire_removed_key() { + let TestState { + mut backend, + ttl_provider, + time_provider, + .. 
+ } = TestState::new(); + + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1))); + ttl_provider.set_expires_in(2, String::from("b"), Some(Duration::from_secs(2))); + backend.set(1, String::from("a")); + backend.remove(&1); + + time_provider.inc(Duration::from_secs(1)); + backend.set(2, String::from("b")); + assert_eq!(backend.get(&1), None); + assert_eq!(backend.get(&2), Some(String::from("b"))); + } + + #[test] + fn test_expire_immediately() { + let TestState { + mut backend, + ttl_provider, + .. + } = TestState::new(); + + ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(0))); + backend.set(1, String::from("a")); + + assert!(backend.is_empty()); + + assert_eq!(backend.get(&1), None); + } + + #[test] + fn test_generic_backend() { + use crate::backend::test_util::test_generic; + + test_generic(|| { + let ttl_provider = Arc::new(NeverTtlProvider::default()); + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let metric_registry = metric::Registry::new(); + let mut backend = PolicyBackend::hashmap_backed(time_provider); + backend.add_policy(TtlPolicy::new( + Arc::clone(&ttl_provider) as _, + "my_cache", + &metric_registry, + )); + backend + }); + } + + struct TestState { + backend: PolicyBackend, + metric_registry: metric::Registry, + ttl_provider: Arc, + time_provider: Arc, + } + + impl TestState { + fn new() -> Self { + let ttl_provider = Arc::new(TestTtlProvider::new()); + let time_provider = Arc::new(MockProvider::new(Time::MIN)); + let metric_registry = metric::Registry::new(); + + let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _); + backend.add_policy(TtlPolicy::new( + Arc::clone(&ttl_provider) as _, + "my_cache", + &metric_registry, + )); + + Self { + backend, + metric_registry, + ttl_provider, + time_provider, + } + } + } + + fn get_expired_metric(metric_registry: &metric::Registry) -> u64 { + let mut reporter = RawReporter::default(); + 
metric_registry.report(&mut reporter); + let observation = reporter + .metric("cache_ttl_expired") + .unwrap() + .observation(&[("name", "my_cache")]) + .unwrap(); + + if let Observation::U64Counter(c) = observation { + *c + } else { + panic!("Wrong observation type") + } + } +} diff --git a/cache_system/src/backend/test_util.rs b/cache_system/src/backend/test_util.rs new file mode 100644 index 0000000..21dae2c --- /dev/null +++ b/cache_system/src/backend/test_util.rs @@ -0,0 +1,112 @@ +use super::CacheBackend; + +/// Generic test set for [`Backend`]. +/// +/// The backend must NOT perform any pruning/deletions during the tests (even though backends are allowed to do that in +/// general). +pub fn test_generic(constructor: F) +where + B: CacheBackend, + F: Fn() -> B, +{ + test_get_empty(constructor()); + test_get_set(constructor()); + test_get_twice(constructor()); + test_override(constructor()); + test_set_remove_get(constructor()); + test_remove_empty(constructor()); + test_readd(constructor()); + test_is_empty(constructor()); +} + +/// Test GET on empty backend. +fn test_get_empty(mut backend: B) +where + B: CacheBackend, +{ + assert_eq!(backend.get(&1), None); +} + +/// Test GET and SET without any overrides. +fn test_get_set(mut backend: B) +where + B: CacheBackend, +{ + backend.set(1, String::from("a")); + backend.set(2, String::from("b")); + + assert_eq!(backend.get(&1), Some(String::from("a"))); + assert_eq!(backend.get(&2), Some(String::from("b"))); + assert_eq!(backend.get(&3), None); +} + +/// Test that a value can be retrieved multiple times. +fn test_get_twice(mut backend: B) +where + B: CacheBackend, +{ + backend.set(1, String::from("a")); + + assert_eq!(backend.get(&1), Some(String::from("a"))); + assert_eq!(backend.get(&1), Some(String::from("a"))); +} + +/// Test that setting a value twice w/o deletion overrides the existing value. 
+fn test_override(mut backend: B) +where + B: CacheBackend, +{ + backend.set(1, String::from("a")); + backend.set(1, String::from("b")); + + assert_eq!(backend.get(&1), Some(String::from("b"))); +} + +/// Test removal of on empty backend. +fn test_remove_empty(mut backend: B) +where + B: CacheBackend, +{ + backend.remove(&1); +} + +/// Test removal of existing values. +fn test_set_remove_get(mut backend: B) +where + B: CacheBackend, +{ + backend.set(1, String::from("a")); + backend.remove(&1); + + assert_eq!(backend.get(&1), None); +} + +/// Test setting a new value after removing it. +fn test_readd(mut backend: B) +where + B: CacheBackend, +{ + backend.set(1, String::from("a")); + backend.remove(&1); + backend.set(1, String::from("b")); + + assert_eq!(backend.get(&1), Some(String::from("b"))); +} + +/// Test `is_empty` check. +fn test_is_empty(mut backend: B) +where + B: CacheBackend, +{ + assert!(backend.is_empty()); + + backend.set(1, String::from("a")); + backend.set(2, String::from("b")); + assert!(!backend.is_empty()); + + backend.remove(&1); + assert!(!backend.is_empty()); + + backend.remove(&2); + assert!(backend.is_empty()); +} diff --git a/cache_system/src/cache/driver.rs b/cache_system/src/cache/driver.rs new file mode 100644 index 0000000..c0c9773 --- /dev/null +++ b/cache_system/src/cache/driver.rs @@ -0,0 +1,452 @@ +//! Main data structure, see [`CacheDriver`]. 
+ +use crate::{ + backend::CacheBackend, + cancellation_safe_future::{CancellationSafeFuture, CancellationSafeFutureReceiver}, + loader::Loader, +}; +use async_trait::async_trait; +use futures::{ + channel::oneshot::{channel, Canceled, Sender}, + future::{BoxFuture, Shared}, + FutureExt, TryFutureExt, +}; +use observability_deps::tracing::debug; +use std::{collections::HashMap, fmt::Debug, future::Future, sync::Arc}; +use tracker::{LockMetrics, Mutex}; + +use super::{Cache, CacheGetStatus, CachePeekStatus}; + +/// Combine a [`CacheBackend`] and a [`Loader`] into a single [`Cache`] +#[derive(Debug)] +pub struct CacheDriver +where + B: CacheBackend + Send + 'static, + L: Loader, +{ + state: Arc>>, + loader: Arc, +} + +impl CacheDriver +where + B: CacheBackend + Send + 'static, + L: Loader, +{ + /// Create new, empty cache with given loader function. + pub fn new(loader: Arc, backend: B, metrics: &metric::Registry, name: &'static str) -> Self { + let metrics = Arc::new(LockMetrics::new( + metrics, + &[("what", "cache_driver_state"), ("cache", name)], + )); + + Self { + state: Arc::new(metrics.new_mutex(CacheState { + cached_entries: backend, + running_queries: HashMap::new(), + tag_counter: 0, + })), + loader, + } + } + + fn start_new_query( + state: &mut CacheState, + state_captured: Arc>>, + loader: Arc, + k: B::K, + extra: L::Extra, + ) -> ( + CancellationSafeFuture>, + SharedReceiver, + ) { + let (tx_main, rx_main) = channel(); + let receiver = rx_main + .map_ok(|v| Arc::new(Mutex::new(v))) + .map_err(Arc::new) + .boxed() + .shared(); + let (tx_set, rx_set) = channel(); + + // generate unique tag + let tag = state.tag_counter; + state.tag_counter += 1; + + // need to wrap the query into a `CancellationSafeFuture` so that it doesn't get cancelled when + // this very request is cancelled + let join_handle_receiver = CancellationSafeFutureReceiver::default(); + let k_captured = k.clone(); + let fut = async move { + let loader_fut = async move { + let submitter = 
ResultSubmitter::new(state_captured, k_captured.clone(), tag); + + // execute the loader + // If we panic here then `tx` will be dropped and the receivers will be + // notified. + let v = loader.load(k_captured, extra).await; + + // remove "running" state and store result + let was_running = submitter.submit(v.clone()); + + if !was_running { + // value was side-loaded, so we cannot populate `v`. Instead block this + // execution branch and wait for `rx_set` to deliver the side-loaded + // result. + loop { + tokio::task::yield_now().await; + } + } + + v + }; + + // prefer the side-loader + let v = futures::select_biased! { + maybe_v = rx_set.fuse() => { + match maybe_v { + Ok(v) => { + // data get side-loaded via `Cache::set`. In this case, we do + // NOT modify the state because there would be a lock-gap. The + // `set` function will do that for us instead. + v + } + Err(_) => { + // sender side is gone, very likely the cache is shutting down + debug!( + "Sender for side-loading data into running query gone.", + ); + return; + } + } + } + v = loader_fut.fuse() => v, + }; + + // broadcast result + // It's OK if the receiver side is gone. 
This might happen during shutdown + tx_main.send(v).ok(); + }; + let fut = CancellationSafeFuture::new(fut, join_handle_receiver.clone()); + + state.running_queries.insert( + k, + RunningQuery { + recv: receiver.clone(), + set: tx_set, + _join_handle: join_handle_receiver, + tag, + }, + ); + + (fut, receiver) + } +} + +#[async_trait] +impl Cache for CacheDriver +where + B: CacheBackend + Send, + L: Loader, +{ + type K = B::K; + type V = B::V; + type GetExtra = L::Extra; + type PeekExtra = (); + + async fn get_with_status( + &self, + k: Self::K, + extra: Self::GetExtra, + ) -> (Self::V, CacheGetStatus) { + // place state locking into its own scope so it doesn't leak into the generator (async + // function) + let (fut, receiver, status) = { + let mut state = self.state.lock(); + + // check if the entry has already been cached + if let Some(v) = state.cached_entries.get(&k) { + return (v, CacheGetStatus::Hit); + } + + // check if there is already a query for this key running + if let Some(running_query) = state.running_queries.get(&k) { + ( + None, + running_query.recv.clone(), + CacheGetStatus::MissAlreadyLoading, + ) + } else { + // requires new query + let (fut, receiver) = Self::start_new_query( + &mut state, + Arc::clone(&self.state), + Arc::clone(&self.loader), + k, + extra, + ); + (Some(fut), receiver, CacheGetStatus::Miss) + } + }; + + // try to run the loader future in this very task context to avoid spawning tokio tasks (which adds latency and + // overhead) + if let Some(fut) = fut { + fut.await; + } + + let v = retrieve_from_shared(receiver).await; + + (v, status) + } + + async fn peek_with_status( + &self, + k: Self::K, + _extra: Self::PeekExtra, + ) -> Option<(Self::V, CachePeekStatus)> { + // place state locking into its own scope so it doesn't leak into the generator (async + // function) + let (receiver, status) = { + let mut state = self.state.lock(); + + // check if the entry has already been cached + if let Some(v) = state.cached_entries.get(&k) { 
+ return Some((v, CachePeekStatus::Hit)); + } + + // check if there is already a query for this key running + if let Some(running_query) = state.running_queries.get(&k) { + ( + running_query.recv.clone(), + CachePeekStatus::MissAlreadyLoading, + ) + } else { + return None; + } + }; + + let v = retrieve_from_shared(receiver).await; + + Some((v, status)) + } + + async fn set(&self, k: Self::K, v: Self::V) { + let maybe_join_handle = { + let mut state = self.state.lock(); + + let maybe_recv = if let Some(running_query) = state.running_queries.remove(&k) { + // it's OK when the receiver side is gone (likely panicked) + running_query.set.send(v.clone()).ok(); + + // When we side-load data into the running task, the task does NOT modify the + // backend, so we have to do that. The reason for not letting the task feed the + // side-loaded data back into `cached_entries` is that we would need to drop the + // state lock here before the task could acquire it, leading to a lock gap. + Some(running_query.recv) + } else { + None + }; + + state.cached_entries.set(k, v); + + maybe_recv + }; + + // drive running query (if any) to completion + if let Some(recv) = maybe_join_handle { + // we do not care if the query died (e.g. due to a panic) + recv.await.ok(); + } + } +} + +impl Drop for CacheDriver +where + B: CacheBackend + Send, + L: Loader, +{ + fn drop(&mut self) { + for _ in self.state.lock().running_queries.drain() {} + } +} + +/// Helper to submit results of running queries. +/// +/// Ensures that running query is removed when dropped (e.g. during panic). +struct ResultSubmitter +where + B: CacheBackend, +{ + state: Arc>>, + tag: u64, + k: Option, + v: Option, +} + +impl ResultSubmitter +where + B: CacheBackend, +{ + fn new(state: Arc>>, k: B::K, tag: u64) -> Self { + Self { + state, + tag, + k: Some(k), + v: None, + } + } + + /// Submit value. + /// + /// Returns `true` if this very query was running. 
+ fn submit(mut self, v: B::V) -> bool { + assert!(self.v.is_none()); + self.v = Some(v); + self.finalize() + } + + /// Finalize request. + /// + /// Returns `true` if this very query was running. + fn finalize(&mut self) -> bool { + let k = self.k.take().expect("finalized twice"); + let mut state = self.state.lock(); + + match state.running_queries.get(&k) { + Some(running_query) if running_query.tag == self.tag => { + state.running_queries.remove(&k); + + if let Some(v) = self.v.take() { + // this very query is in charge of the key, so store in in the + // underlying cache + state.cached_entries.set(k, v); + } + + true + } + _ => { + // This query is actually not really running any longer but got + // shut down, e.g. due to side loading. Do NOT store the + // generated value in the underlying cache. + + false + } + } + } +} + +impl Drop for ResultSubmitter +where + B: CacheBackend, +{ + fn drop(&mut self) { + if self.k.is_some() { + // not finalized yet + self.finalize(); + } + } +} + +/// A [`tokio::sync::oneshot::Receiver`] that can be cloned. +/// +/// The types are: +/// +/// - `Arc>`: Ensures that we can clone `V` without requiring `V: Sync`. At the same time +/// the reference to `V` (i.e. the `Arc`) must be cloneable for `Shared` +/// - `Arc`: Is required because `RecvError` is not `Clone` but `Shared` requires that. +/// - `BoxFuture`: The transformation from `Result` to `Result>, +/// Arc>` results in a kinda messy type and we wanna erase that. +/// - `Shared`: Allow the receiver to be cloned and be awaited from multiple places. +type SharedReceiver = Shared>, Arc>>>; + +/// Retrieve data from shared receiver. +async fn retrieve_from_shared(receiver: SharedReceiver) -> V +where + V: Clone + Send, +{ + receiver + .await + .expect("cache loader panicked, see logs") + .lock() + .clone() +} + +/// State for coordinating the execution of a single running query. +#[derive(Debug)] +struct RunningQuery { + /// A receiver that can await the result as well. 
+ recv: SharedReceiver, + + /// A sender that enables setting entries while the query is running. + #[allow(dead_code)] + set: Sender, + + /// A handle for the task that is currently executing the query. + /// + /// The handle can be used to abort the running query, e.g. when dropping the cache. + /// + /// This is "dead code" because we only store it to keep the future alive. There's no direct interaction. + _join_handle: CancellationSafeFutureReceiver<()>, + + /// Tag so that queries for the same key (e.g. when starting, side-loading, starting again) can + /// be told apart. + tag: u64, +} + +/// Inner cache state that is usually guarded by a lock. +/// +/// The state parts must be updated in a consistent manner, i.e. while using the same lock guard. +#[derive(Debug)] +struct CacheState +where + B: CacheBackend, +{ + /// Cached entires (i.e. queries completed). + cached_entries: B, + + /// Currently running queries indexed by cache key. + running_queries: HashMap>, + + /// Tag counter for running queries. + tag_counter: u64, +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::{ + cache::test_util::{run_test_generic, TestAdapter}, + loader::test_util::TestLoader, + }; + + use super::*; + + #[tokio::test] + async fn test_generic() { + run_test_generic(MyTestAdapter).await; + } + + struct MyTestAdapter; + + impl TestAdapter for MyTestAdapter { + type GetExtra = bool; + type PeekExtra = (); + type Cache = CacheDriver, TestLoader>; + + fn construct(&self, loader: Arc) -> Arc { + Arc::new(CacheDriver::new( + Arc::clone(&loader) as _, + HashMap::new(), + &metric::Registry::default(), + "test", + )) + } + + fn get_extra(&self, inner: bool) -> Self::GetExtra { + inner + } + + fn peek_extra(&self) -> Self::PeekExtra {} + } +} diff --git a/cache_system/src/cache/metrics.rs b/cache_system/src/cache/metrics.rs new file mode 100644 index 0000000..c72364a --- /dev/null +++ b/cache_system/src/cache/metrics.rs @@ -0,0 +1,718 @@ +//! 
Metrics instrumentation for [`Cache`]s. +use std::{fmt::Debug, sync::Arc}; + +use async_trait::async_trait; +use iox_time::{Time, TimeProvider}; +use metric::{Attributes, DurationHistogram, U64Counter}; +use observability_deps::tracing::warn; +use trace::span::{Span, SpanRecorder}; + +use super::{Cache, CacheGetStatus, CachePeekStatus}; + +/// Struct containing all the metrics +#[derive(Debug)] +struct Metrics { + time_provider: Arc, + metric_get_hit: DurationHistogram, + metric_get_miss: DurationHistogram, + metric_get_miss_already_loading: DurationHistogram, + metric_get_cancelled: DurationHistogram, + metric_peek_hit: DurationHistogram, + metric_peek_miss: DurationHistogram, + metric_peek_miss_already_loading: DurationHistogram, + metric_peek_cancelled: DurationHistogram, + metric_set: U64Counter, +} + +impl Metrics { + fn new( + name: &'static str, + time_provider: Arc, + metric_registry: &metric::Registry, + ) -> Self { + let attributes = Attributes::from(&[("name", name)]); + + let mut attributes_get = attributes.clone(); + let metric_get = metric_registry + .register_metric::("iox_cache_get", "Cache GET requests"); + + attributes_get.insert("status", "hit"); + let metric_get_hit = metric_get.recorder(attributes_get.clone()); + + attributes_get.insert("status", "miss"); + let metric_get_miss = metric_get.recorder(attributes_get.clone()); + + attributes_get.insert("status", "miss_already_loading"); + let metric_get_miss_already_loading = metric_get.recorder(attributes_get.clone()); + + attributes_get.insert("status", "cancelled"); + let metric_get_cancelled = metric_get.recorder(attributes_get); + + let mut attributes_peek = attributes.clone(); + let metric_peek = metric_registry + .register_metric::("iox_cache_peek", "Cache PEEK requests"); + + attributes_peek.insert("status", "hit"); + let metric_peek_hit = metric_peek.recorder(attributes_peek.clone()); + + attributes_peek.insert("status", "miss"); + let metric_peek_miss = 
metric_peek.recorder(attributes_peek.clone()); + + attributes_peek.insert("status", "miss_already_loading"); + let metric_peek_miss_already_loading = metric_peek.recorder(attributes_peek.clone()); + + attributes_peek.insert("status", "cancelled"); + let metric_peek_cancelled = metric_peek.recorder(attributes_peek); + + let metric_set = metric_registry + .register_metric::("iox_cache_set", "Cache SET requests.") + .recorder(attributes); + + Self { + time_provider, + metric_get_hit, + metric_get_miss, + metric_get_miss_already_loading, + metric_get_cancelled, + metric_peek_hit, + metric_peek_miss, + metric_peek_miss_already_loading, + metric_peek_cancelled, + metric_set, + } + } +} + +/// Wraps given cache with metrics. +#[derive(Debug)] +pub struct CacheWithMetrics +where + C: Cache, +{ + inner: C, + metrics: Metrics, +} + +impl CacheWithMetrics +where + C: Cache, +{ + /// Create new metrics wrapper around given cache. + pub fn new( + inner: C, + name: &'static str, + time_provider: Arc, + metric_registry: &metric::Registry, + ) -> Self { + Self { + inner, + metrics: Metrics::new(name, time_provider, metric_registry), + } + } +} + +#[async_trait] +impl Cache for CacheWithMetrics +where + C: Cache, +{ + type K = C::K; + type V = C::V; + type GetExtra = (C::GetExtra, Option); + type PeekExtra = (C::PeekExtra, Option); + + async fn get_with_status( + &self, + k: Self::K, + extra: Self::GetExtra, + ) -> (Self::V, CacheGetStatus) { + let (extra, span) = extra; + let mut set_on_drop = SetGetMetricOnDrop::new(&self.metrics, span); + let (v, status) = self.inner.get_with_status(k, extra).await; + set_on_drop.status = Some(status); + + (v, status) + } + + async fn peek_with_status( + &self, + k: Self::K, + extra: Self::PeekExtra, + ) -> Option<(Self::V, CachePeekStatus)> { + let (extra, span) = extra; + let mut set_on_drop = SetPeekMetricOnDrop::new(&self.metrics, span); + let res = self.inner.peek_with_status(k, extra).await; + set_on_drop.status = 
Some(res.as_ref().map(|(_v, status)| *status)); + + res + } + + async fn set(&self, k: Self::K, v: Self::V) { + self.inner.set(k, v).await; + self.metrics.metric_set.inc(1); + } +} + +/// Helper that set's GET metrics on drop depending on the `status`. +/// +/// A drop might happen due to completion (in which case the `status` should be set) or if the future is cancelled (in +/// which case the `status` is `None`). +struct SetGetMetricOnDrop<'a> { + metrics: &'a Metrics, + t_start: Time, + status: Option, + span_recorder: SpanRecorder, +} + +impl<'a> SetGetMetricOnDrop<'a> { + fn new(metrics: &'a Metrics, span: Option) -> Self { + let t_start = metrics.time_provider.now(); + + Self { + metrics, + t_start, + status: None, + span_recorder: SpanRecorder::new(span), + } + } +} + +impl<'a> Drop for SetGetMetricOnDrop<'a> { + fn drop(&mut self) { + let t_end = self.metrics.time_provider.now(); + + match t_end.checked_duration_since(self.t_start) { + Some(duration) => { + match self.status { + Some(CacheGetStatus::Hit) => &self.metrics.metric_get_hit, + Some(CacheGetStatus::Miss) => &self.metrics.metric_get_miss, + Some(CacheGetStatus::MissAlreadyLoading) => { + &self.metrics.metric_get_miss_already_loading + } + None => &self.metrics.metric_get_cancelled, + } + .record(duration); + } + None => { + warn!("Clock went backwards, not recording cache GET duration"); + } + } + + if let Some(status) = self.status { + self.span_recorder.ok(status.name()); + } + } +} + +/// Helper that set's PEEK metrics on drop depending on the `status`. +/// +/// A drop might happen due to completion (in which case the `status` should be set) or if the future is cancelled (in +/// which case the `status` is `None`). 
+struct SetPeekMetricOnDrop<'a> { + metrics: &'a Metrics, + t_start: Time, + status: Option>, + span_recorder: SpanRecorder, +} + +impl<'a> SetPeekMetricOnDrop<'a> { + fn new(metrics: &'a Metrics, span: Option) -> Self { + let t_start = metrics.time_provider.now(); + + Self { + metrics, + t_start, + status: None, + span_recorder: SpanRecorder::new(span), + } + } +} + +impl<'a> Drop for SetPeekMetricOnDrop<'a> { + fn drop(&mut self) { + let t_end = self.metrics.time_provider.now(); + + match t_end.checked_duration_since(self.t_start) { + Some(duration) => { + match self.status { + Some(Some(CachePeekStatus::Hit)) => &self.metrics.metric_peek_hit, + Some(Some(CachePeekStatus::MissAlreadyLoading)) => { + &self.metrics.metric_peek_miss_already_loading + } + Some(None) => &self.metrics.metric_peek_miss, + None => &self.metrics.metric_peek_cancelled, + } + .record(duration); + } + None => { + warn!("Clock went backwards, not recording cache PEEK duration"); + } + } + + if let Some(status) = self.status { + self.span_recorder + .ok(status.map(|status| status.name()).unwrap_or("miss")); + } + } +} + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, time::Duration}; + + use futures::{stream::FuturesUnordered, StreamExt}; + use iox_time::{MockProvider, Time}; + use metric::{HistogramObservation, Observation, RawReporter}; + use tokio::sync::Barrier; + use trace::{span::SpanStatus, RingBufferTraceCollector}; + + use crate::{ + cache::{ + driver::CacheDriver, + test_util::{run_test_generic, TestAdapter}, + }, + loader::test_util::TestLoader, + test_util::{AbortAndWaitExt, EnsurePendingExt}, + }; + + use super::*; + + #[tokio::test] + async fn test_generic() { + run_test_generic(MyTestAdapter).await; + } + + struct MyTestAdapter; + + impl TestAdapter for MyTestAdapter { + type GetExtra = (bool, Option); + type PeekExtra = ((), Option); + type Cache = CacheWithMetrics, TestLoader>>; + + fn construct(&self, loader: Arc) -> Arc { + 
TestMetricsCache::new_with_loader(loader).cache + } + + fn get_extra(&self, inner: bool) -> Self::GetExtra { + (inner, None) + } + + fn peek_extra(&self) -> Self::PeekExtra { + ((), None) + } + } + + #[tokio::test] + async fn test_get() { + let test_cache = TestMetricsCache::new(); + + let traces = Arc::new(RingBufferTraceCollector::new(1_000)); + + let mut reporter = RawReporter::default(); + test_cache.metric_registry.report(&mut reporter); + + for status in ["hit", "miss", "miss_already_loading", "cancelled"] { + let hist = get_metric_cache_get(&reporter, status); + assert_eq!(hist.sample_count(), 0); + assert_eq!(hist.total, Duration::from_secs(0)); + } + + test_cache.loader.block_global(); + + let barrier_pending_1 = Arc::new(Barrier::new(2)); + let barrier_pending_1_captured = Arc::clone(&barrier_pending_1); + let traces_captured = Arc::clone(&traces); + let cache_captured = Arc::clone(&test_cache.cache); + let join_handle_1 = tokio::task::spawn(async move { + cache_captured + .get( + 1, + ( + true, + Some(Span::root("miss", Arc::clone(&traces_captured) as _)), + ), + ) + .ensure_pending(barrier_pending_1_captured) + .await + }); + + barrier_pending_1.wait().await; + let d1 = Duration::from_secs(1); + test_cache.time_provider.inc(d1); + let barrier_pending_2 = Arc::new(Barrier::new(2)); + let barrier_pending_2_captured = Arc::clone(&barrier_pending_2); + let traces_captured = Arc::clone(&traces); + let cache_captured = Arc::clone(&test_cache.cache); + let n_miss_already_loading = 10; + let join_handle_2 = tokio::task::spawn(async move { + (0..n_miss_already_loading) + .map(|_| { + cache_captured.get( + 1, + ( + true, + Some(Span::root( + "miss_already_loading", + Arc::clone(&traces_captured) as _, + )), + ), + ) + }) + .collect::>() + .collect::>() + .ensure_pending(barrier_pending_2_captured) + .await + }); + + barrier_pending_2.wait().await; + let d2 = Duration::from_secs(3); + test_cache.time_provider.inc(d2); + test_cache.loader.mock_next(1, "v".into()); 
+ test_cache.loader.unblock_global(); + + join_handle_1.await.unwrap(); + join_handle_2.await.unwrap(); + + test_cache.loader.block_global(); + test_cache.time_provider.inc(Duration::from_secs(10)); + let n_hit = 100; + for _ in 0..n_hit { + test_cache + .cache + .get(1, (true, Some(Span::root("hit", Arc::clone(&traces) as _)))) + .await; + } + + let n_cancelled = 200; + let barrier_pending_3 = Arc::new(Barrier::new(2)); + let barrier_pending_3_captured = Arc::clone(&barrier_pending_3); + let traces_captured = Arc::clone(&traces); + let cache_captured = Arc::clone(&test_cache.cache); + let join_handle_3 = tokio::task::spawn(async move { + (0..n_cancelled) + .map(|_| { + cache_captured.get( + 2, + ( + true, + Some(Span::root("cancelled", Arc::clone(&traces_captured) as _)), + ), + ) + }) + .collect::>() + .collect::>() + .ensure_pending(barrier_pending_3_captured) + .await + }); + + barrier_pending_3.wait().await; + let d3 = Duration::from_secs(20); + test_cache.time_provider.inc(d3); + join_handle_3.abort_and_wait().await; + + let mut reporter = RawReporter::default(); + test_cache.metric_registry.report(&mut reporter); + + let hist = get_metric_cache_get(&reporter, "hit"); + assert_eq!(hist.sample_count(), n_hit); + // "hit"s are instant because there's no lock contention + assert_eq!(hist.total, Duration::from_secs(0)); + + let hist = get_metric_cache_get(&reporter, "miss"); + let n = 1; + assert_eq!(hist.sample_count(), n); + assert_eq!(hist.total, (n as u32) * (d1 + d2)); + + let hist = get_metric_cache_get(&reporter, "miss_already_loading"); + assert_eq!(hist.sample_count(), n_miss_already_loading); + assert_eq!(hist.total, (n_miss_already_loading as u32) * d2); + + let hist = get_metric_cache_get(&reporter, "cancelled"); + assert_eq!(hist.sample_count(), n_cancelled); + assert_eq!(hist.total, (n_cancelled as u32) * d3); + + // check spans + assert_n_spans(&traces, "hit", SpanStatus::Ok, n_hit as usize); + assert_n_spans(&traces, "miss", SpanStatus::Ok, 1); + 
assert_n_spans( + &traces, + "miss_already_loading", + SpanStatus::Ok, + n_miss_already_loading as usize, + ); + assert_n_spans( + &traces, + "cancelled", + SpanStatus::Unknown, + n_cancelled as usize, + ); + } + + #[tokio::test] + async fn test_peek() { + let test_cache = TestMetricsCache::new(); + + let traces = Arc::new(RingBufferTraceCollector::new(1_000)); + + let mut reporter = RawReporter::default(); + test_cache.metric_registry.report(&mut reporter); + + for status in ["hit", "miss", "miss_already_loading", "cancelled"] { + let hist = get_metric_cache_peek(&reporter, status); + assert_eq!(hist.sample_count(), 0); + assert_eq!(hist.total, Duration::from_secs(0)); + } + + test_cache.loader.block_global(); + + test_cache + .cache + .peek(1, ((), Some(Span::root("miss", Arc::clone(&traces) as _)))) + .await; + + let barrier_pending_1 = Arc::new(Barrier::new(2)); + let barrier_pending_1_captured = Arc::clone(&barrier_pending_1); + let cache_captured = Arc::clone(&test_cache.cache); + let join_handle_1 = tokio::task::spawn(async move { + cache_captured + .get(1, (true, None)) + .ensure_pending(barrier_pending_1_captured) + .await + }); + + barrier_pending_1.wait().await; + let d1 = Duration::from_secs(1); + test_cache.time_provider.inc(d1); + let barrier_pending_2 = Arc::new(Barrier::new(2)); + let barrier_pending_2_captured = Arc::clone(&barrier_pending_2); + let traces_captured = Arc::clone(&traces); + let cache_captured = Arc::clone(&test_cache.cache); + let n_miss_already_loading = 10; + let join_handle_2 = tokio::task::spawn(async move { + (0..n_miss_already_loading) + .map(|_| { + cache_captured.peek( + 1, + ( + (), + Some(Span::root( + "miss_already_loading", + Arc::clone(&traces_captured) as _, + )), + ), + ) + }) + .collect::>() + .collect::>() + .ensure_pending(barrier_pending_2_captured) + .await + }); + + barrier_pending_2.wait().await; + let d2 = Duration::from_secs(3); + test_cache.time_provider.inc(d2); + test_cache.loader.mock_next(1, "v".into()); 
+ test_cache.loader.unblock_global(); + + join_handle_1.await.unwrap(); + join_handle_2.await.unwrap(); + + test_cache.loader.block_global(); + test_cache.time_provider.inc(Duration::from_secs(10)); + let n_hit = 100; + for _ in 0..n_hit { + test_cache + .cache + .peek(1, ((), Some(Span::root("hit", Arc::clone(&traces) as _)))) + .await; + } + + let n_cancelled = 200; + let barrier_pending_3 = Arc::new(Barrier::new(2)); + let barrier_pending_3_captured = Arc::clone(&barrier_pending_3); + let cache_captured = Arc::clone(&test_cache.cache); + tokio::task::spawn(async move { + cache_captured + .get(2, (true, None)) + .ensure_pending(barrier_pending_3_captured) + .await + }); + barrier_pending_3.wait().await; + let barrier_pending_4 = Arc::new(Barrier::new(2)); + let barrier_pending_4_captured = Arc::clone(&barrier_pending_4); + let traces_captured = Arc::clone(&traces); + let cache_captured = Arc::clone(&test_cache.cache); + let join_handle_3 = tokio::task::spawn(async move { + (0..n_cancelled) + .map(|_| { + cache_captured.peek( + 2, + ( + (), + Some(Span::root("cancelled", Arc::clone(&traces_captured) as _)), + ), + ) + }) + .collect::>() + .collect::>() + .ensure_pending(barrier_pending_4_captured) + .await + }); + + barrier_pending_4.wait().await; + let d3 = Duration::from_secs(20); + test_cache.time_provider.inc(d3); + join_handle_3.abort_and_wait().await; + + let mut reporter = RawReporter::default(); + test_cache.metric_registry.report(&mut reporter); + + let hist = get_metric_cache_peek(&reporter, "hit"); + assert_eq!(hist.sample_count(), n_hit); + // "hit"s are instant because there's no lock contention + assert_eq!(hist.total, Duration::from_secs(0)); + + let hist = get_metric_cache_peek(&reporter, "miss"); + let n = 1; + assert_eq!(hist.sample_count(), n); + // "miss"es are instant + assert_eq!(hist.total, Duration::from_secs(0)); + + let hist = get_metric_cache_peek(&reporter, "miss_already_loading"); + assert_eq!(hist.sample_count(), 
n_miss_already_loading); + assert_eq!(hist.total, (n_miss_already_loading as u32) * d2); + + let hist = get_metric_cache_peek(&reporter, "cancelled"); + assert_eq!(hist.sample_count(), n_cancelled); + assert_eq!(hist.total, (n_cancelled as u32) * d3); + + // check spans + assert_n_spans(&traces, "hit", SpanStatus::Ok, n_hit as usize); + assert_n_spans(&traces, "miss", SpanStatus::Ok, 1); + assert_n_spans( + &traces, + "miss_already_loading", + SpanStatus::Ok, + n_miss_already_loading as usize, + ); + assert_n_spans( + &traces, + "cancelled", + SpanStatus::Unknown, + n_cancelled as usize, + ); + } + + #[tokio::test] + async fn test_set() { + let test_cache = TestMetricsCache::new(); + + let mut reporter = RawReporter::default(); + test_cache.metric_registry.report(&mut reporter); + assert_eq!( + reporter + .metric("iox_cache_set") + .unwrap() + .observation(&[("name", "test")]) + .unwrap(), + &Observation::U64Counter(0) + ); + + test_cache.cache.set(1, String::from("foo")).await; + + let mut reporter = RawReporter::default(); + test_cache.metric_registry.report(&mut reporter); + assert_eq!( + reporter + .metric("iox_cache_set") + .unwrap() + .observation(&[("name", "test")]) + .unwrap(), + &Observation::U64Counter(1) + ); + } + + struct TestMetricsCache { + loader: Arc, + time_provider: Arc, + metric_registry: metric::Registry, + cache: Arc, TestLoader>>>, + } + + impl TestMetricsCache { + fn new() -> Self { + Self::new_with_loader(Arc::new(TestLoader::default())) + } + + fn new_with_loader(loader: Arc) -> Self { + let inner = CacheDriver::new( + Arc::clone(&loader) as _, + HashMap::new(), + &metric::Registry::default(), + "test", + ); + let time_provider = + Arc::new(MockProvider::new(Time::from_timestamp_millis(0).unwrap())); + let metric_registry = metric::Registry::new(); + let cache = Arc::new(CacheWithMetrics::new( + inner, + "test", + Arc::clone(&time_provider) as _, + &metric_registry, + )); + + Self { + loader, + time_provider, + metric_registry, + cache, + 
} + } + } + + fn get_metric_cache_get( + reporter: &RawReporter, + status: &'static str, + ) -> HistogramObservation { + if let Observation::DurationHistogram(hist) = reporter + .metric("iox_cache_get") + .unwrap() + .observation(&[("name", "test"), ("status", status)]) + .unwrap() + { + hist.clone() + } else { + panic!("Wrong observation type"); + } + } + + fn get_metric_cache_peek( + reporter: &RawReporter, + status: &'static str, + ) -> HistogramObservation { + if let Observation::DurationHistogram(hist) = reporter + .metric("iox_cache_peek") + .unwrap() + .observation(&[("name", "test"), ("status", status)]) + .unwrap() + { + hist.clone() + } else { + panic!("Wrong observation type"); + } + } + + fn assert_n_spans( + traces: &RingBufferTraceCollector, + name: &'static str, + status: SpanStatus, + expected: usize, + ) { + let actual = traces + .spans() + .into_iter() + .filter(|span| (span.name == name) && (span.status == status)) + .count(); + assert_eq!(actual, expected); + } +} diff --git a/cache_system/src/cache/mod.rs b/cache_system/src/cache/mod.rs new file mode 100644 index 0000000..ba3d541 --- /dev/null +++ b/cache_system/src/cache/mod.rs @@ -0,0 +1,167 @@ +//! Top-level trait ([`Cache`]) that provides a fully functional cache. +//! +//! Caches usually combine a [backend](crate::backend) with a [loader](crate::loader). The easiest way to achieve that +//! is to use [`CacheDriver`](crate::cache::driver::CacheDriver). Caches might also wrap inner caches to provide certain +//! extra functionality like metrics. +use std::{fmt::Debug, hash::Hash}; + +use async_trait::async_trait; + +pub mod driver; +pub mod metrics; + +#[cfg(test)] +mod test_util; + +/// Status of a [`Cache`] [GET](Cache::get_with_status) request. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CacheGetStatus { + /// The requested entry was present in the storage backend. 
+ Hit, + + /// The requested entry was NOT present in the storage backend and the loader had no previous query running. + Miss, + + /// The requested entry was NOT present in the storage backend, but there was already a loader query running for + /// this particular key. + MissAlreadyLoading, +} + +impl CacheGetStatus { + /// Get human and machine readable name. + pub fn name(&self) -> &'static str { + match self { + Self::Hit => "hit", + Self::Miss => "miss", + Self::MissAlreadyLoading => "miss_already_loading", + } + } +} + +/// Status of a [`Cache`] [PEEK](Cache::peek_with_status) request. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CachePeekStatus { + /// The requested entry was present in the storage backend. + Hit, + + /// The requested entry was NOT present in the storage backend, but there was already a loader query running for + /// this particular key. + MissAlreadyLoading, +} + +impl CachePeekStatus { + /// Get human and machine redable name. + pub fn name(&self) -> &'static str { + match self { + Self::Hit => "hit", + Self::MissAlreadyLoading => "miss_already_loading", + } + } +} + +/// High-level cache implementation. +/// +/// # Concurrency +/// +/// Multiple cache requests for different keys can run at the same time. When data is requested for +/// the same key the underlying loader will only be polled once, even when the requests are made +/// while the loader is still running. +/// +/// # Cancellation +/// +/// Canceling a [`get`](Self::get) request will NOT cancel the underlying loader. The data will +/// still be cached. +/// +/// # Panic +/// +/// If the underlying loader panics, all currently running [`get`](Self::get) requests will panic. +/// The data will NOT be cached. +#[async_trait] +pub trait Cache: Debug + Send + Sync + 'static { + /// Cache key. + type K: Clone + Eq + Hash + Debug + Ord + Send + 'static; + + /// Cache value. 
+ type V: Clone + Debug + Send + 'static; + + /// Extra data that is provided during [`GET`](Self::get) but that is NOT part of the cache key. + type GetExtra: Debug + Send + 'static; + + /// Extra data that is provided during [`PEEK`](Self::peek) but that is NOT part of the cache key. + type PeekExtra: Debug + Send + 'static; + + /// Get value from cache. + /// + /// Note that `extra` is only used if the key is missing from the storage backend and no loader query is running yet. + async fn get(&self, k: Self::K, extra: Self::GetExtra) -> Self::V { + self.get_with_status(k, extra).await.0 + } + + /// Get value from cache and the [status](CacheGetStatus). + /// + /// Note that `extra` is only used if the key is missing from the storage backend and no loader query is running yet. + async fn get_with_status(&self, k: Self::K, extra: Self::GetExtra) + -> (Self::V, CacheGetStatus); + + /// Peek value from cache. + /// + /// In contrast to [`get`](Self::get) this will only return a value if there is a stored value or the value loading + /// is already in progress. This will NOT start a new loading task. + /// + /// Note that `extra` is only used if the key is missing from the storage backend and no loader query is running yet. + async fn peek(&self, k: Self::K, extra: Self::PeekExtra) -> Option { + self.peek_with_status(k, extra).await.map(|(v, _status)| v) + } + + /// Peek value from cache and the [status](CachePeekStatus). + /// + /// In contrast to [`get_with_status`](Self::get_with_status) this will only return a value if there is a stored + /// value or the value loading is already in progress. This will NOT start a new loading task. + /// + /// Note that `extra` is only used if the key is missing from the storage backend and no loader query is running yet. + async fn peek_with_status( + &self, + k: Self::K, + extra: Self::PeekExtra, + ) -> Option<(Self::V, CachePeekStatus)>; + + /// Side-load an entry into the cache. 
+ /// + /// This will also complete a currently running request for this key. + async fn set(&self, k: Self::K, v: Self::V); +} + +#[async_trait] +impl Cache + for Box> +where + K: Clone + Eq + Hash + Debug + Ord + Send + 'static, + V: Clone + Debug + Send + 'static, + GetExtra: Debug + Send + 'static, + PeekExtra: Debug + Send + 'static, +{ + type K = K; + type V = V; + type GetExtra = GetExtra; + type PeekExtra = PeekExtra; + + async fn get_with_status( + &self, + k: Self::K, + extra: Self::GetExtra, + ) -> (Self::V, CacheGetStatus) { + self.as_ref().get_with_status(k, extra).await + } + + async fn peek_with_status( + &self, + k: Self::K, + extra: Self::PeekExtra, + ) -> Option<(Self::V, CachePeekStatus)> { + self.as_ref().peek_with_status(k, extra).await + } + + async fn set(&self, k: Self::K, v: Self::V) { + self.as_ref().set(k, v).await + } +} diff --git a/cache_system/src/cache/test_util.rs b/cache_system/src/cache/test_util.rs new file mode 100644 index 0000000..b149eec --- /dev/null +++ b/cache_system/src/cache/test_util.rs @@ -0,0 +1,462 @@ +use std::{sync::Arc, time::Duration}; + +use tokio::sync::Barrier; + +use crate::{ + cache::{CacheGetStatus, CachePeekStatus}, + loader::test_util::TestLoader, + test_util::{AbortAndWaitExt, EnsurePendingExt}, +}; + +use super::Cache; + +/// Interface between generic tests and a concrete cache type. +pub trait TestAdapter: Send + Sync + 'static { + /// Extra information for GET. + type GetExtra: Send; + + /// Extra information for PEEK. + type PeekExtra: Send; + + /// Cache type. + type Cache: Cache; + + /// Create new cache with given loader. + fn construct(&self, loader: Arc) -> Arc; + + /// Build [`GetExtra`](Self::GetExtra). + /// + /// Must contain a [`bool`] payload that is later included into the value string for testing purposes. + fn get_extra(&self, inner: bool) -> Self::GetExtra; + + /// Build [`PeekExtra`](Self::PeekExtra). + fn peek_extra(&self) -> Self::PeekExtra; +} + +/// Setup test. 
+fn setup(adapter: &T) -> (Arc, Arc) +where + T: TestAdapter, +{ + let loader = Arc::new(TestLoader::default()); + let cache = adapter.construct(Arc::clone(&loader)); + (cache, loader) +} + +pub async fn run_test_generic(adapter: T) +where + T: TestAdapter, +{ + let adapter = Arc::new(adapter); + + test_answers_are_correct(Arc::clone(&adapter)).await; + test_linear_memory(Arc::clone(&adapter)).await; + test_concurrent_query_loads_once(Arc::clone(&adapter)).await; + test_queries_are_parallelized(Arc::clone(&adapter)).await; + test_cancel_request(Arc::clone(&adapter)).await; + test_panic_request(Arc::clone(&adapter)).await; + test_drop_cancels_loader(Arc::clone(&adapter)).await; + test_set_before_request(Arc::clone(&adapter)).await; + test_set_during_request(Arc::clone(&adapter)).await; +} + +async fn test_answers_are_correct(adapter: Arc) +where + T: TestAdapter, +{ + let (cache, loader) = setup(adapter.as_ref()); + + loader.mock_next(1, "res_1".to_owned()); + loader.mock_next(2, "res_2".to_owned()); + + assert_eq!( + cache.get(1, adapter.get_extra(true)).await, + String::from("res_1") + ); + assert_eq!( + cache.peek(1, adapter.peek_extra()).await, + Some(String::from("res_1")) + ); + assert_eq!( + cache.get(2, adapter.get_extra(false)).await, + String::from("res_2") + ); + assert_eq!( + cache.peek(2, adapter.peek_extra()).await, + Some(String::from("res_2")) + ); +} + +async fn test_linear_memory(adapter: Arc) +where + T: TestAdapter, +{ + let (cache, loader) = setup(adapter.as_ref()); + + loader.mock_next(1, "res_1".to_owned()); + loader.mock_next(2, "res_2".to_owned()); + + assert_eq!(cache.peek_with_status(1, adapter.peek_extra()).await, None,); + assert_eq!( + cache.get_with_status(1, adapter.get_extra(true)).await, + (String::from("res_1"), CacheGetStatus::Miss), + ); + assert_eq!( + cache.get_with_status(1, adapter.get_extra(false)).await, + (String::from("res_1"), CacheGetStatus::Hit), + ); + assert_eq!( + cache.peek_with_status(1, 
adapter.peek_extra()).await, + Some((String::from("res_1"), CachePeekStatus::Hit)), + ); + assert_eq!( + cache.get_with_status(2, adapter.get_extra(false)).await, + (String::from("res_2"), CacheGetStatus::Miss), + ); + assert_eq!( + cache.get_with_status(2, adapter.get_extra(false)).await, + (String::from("res_2"), CacheGetStatus::Hit), + ); + assert_eq!( + cache.get_with_status(1, adapter.get_extra(true)).await, + (String::from("res_1"), CacheGetStatus::Hit), + ); + assert_eq!( + cache.peek_with_status(1, adapter.peek_extra()).await, + Some((String::from("res_1"), CachePeekStatus::Hit)), + ); + + assert_eq!(loader.loaded(), vec![(1, true), (2, false)]); +} + +async fn test_concurrent_query_loads_once(adapter: Arc) +where + T: TestAdapter, +{ + let (cache, loader) = setup(adapter.as_ref()); + + loader.block_global(); + + let adapter_captured = Arc::clone(&adapter); + let cache_captured = Arc::clone(&cache); + let barrier_pending_1 = Arc::new(Barrier::new(2)); + let barrier_pending_1_captured = Arc::clone(&barrier_pending_1); + let handle_1 = tokio::spawn(async move { + cache_captured + .get_with_status(1, adapter_captured.get_extra(true)) + .ensure_pending(barrier_pending_1_captured) + .await + }); + + barrier_pending_1.wait().await; + + let barrier_pending_2 = Arc::new(Barrier::new(3)); + + let adapter_captured = Arc::clone(&adapter); + let cache_captured = Arc::clone(&cache); + let barrier_pending_2_captured = Arc::clone(&barrier_pending_2); + let handle_2 = tokio::spawn(async move { + // use a different `extra` here to proof that the first one was used + cache_captured + .get_with_status(1, adapter_captured.get_extra(false)) + .ensure_pending(barrier_pending_2_captured) + .await + }); + let barrier_pending_2_captured = Arc::clone(&barrier_pending_2); + let handle_3 = tokio::spawn(async move { + // use a different `extra` here to proof that the first one was used + cache + .peek_with_status(1, adapter.peek_extra()) + .ensure_pending(barrier_pending_2_captured) + 
.await + }); + + barrier_pending_2.wait().await; + loader.mock_next(1, "res_1".to_owned()); + // Shouldn't issue concurrent load requests for the same key + let n_blocked = loader.unblock_global(); + assert_eq!(n_blocked, 1); + + assert_eq!( + handle_1.await.unwrap(), + (String::from("res_1"), CacheGetStatus::Miss), + ); + assert_eq!( + handle_2.await.unwrap(), + (String::from("res_1"), CacheGetStatus::MissAlreadyLoading), + ); + assert_eq!( + handle_3.await.unwrap(), + Some((String::from("res_1"), CachePeekStatus::MissAlreadyLoading)), + ); + + assert_eq!(loader.loaded(), vec![(1, true)]); +} + +async fn test_queries_are_parallelized(adapter: Arc) +where + T: TestAdapter, +{ + let (cache, loader) = setup(adapter.as_ref()); + + loader.block_global(); + + let barrier = Arc::new(Barrier::new(4)); + + let adapter_captured = Arc::clone(&adapter); + let cache_captured = Arc::clone(&cache); + let barrier_captured = Arc::clone(&barrier); + let handle_1 = tokio::spawn(async move { + cache_captured + .get(1, adapter_captured.get_extra(true)) + .ensure_pending(barrier_captured) + .await + }); + + let adapter_captured = Arc::clone(&adapter); + let cache_captured = Arc::clone(&cache); + let barrier_captured = Arc::clone(&barrier); + let handle_2 = tokio::spawn(async move { + cache_captured + .get(1, adapter_captured.get_extra(true)) + .ensure_pending(barrier_captured) + .await + }); + + let barrier_captured = Arc::clone(&barrier); + let handle_3 = tokio::spawn(async move { + cache + .get(2, adapter.get_extra(false)) + .ensure_pending(barrier_captured) + .await + }); + + barrier.wait().await; + + loader.mock_next(1, "res_1".to_owned()); + loader.mock_next(2, "res_2".to_owned()); + + let n_blocked = loader.unblock_global(); + assert_eq!(n_blocked, 2); + + assert_eq!(handle_1.await.unwrap(), String::from("res_1")); + assert_eq!(handle_2.await.unwrap(), String::from("res_1")); + assert_eq!(handle_3.await.unwrap(), String::from("res_2")); + + assert_eq!(loader.loaded(), vec![(1, 
true), (2, false)]); +} + +async fn test_cancel_request(adapter: Arc) +where + T: TestAdapter, +{ + let (cache, loader) = setup(adapter.as_ref()); + + loader.block_global(); + + let barrier_pending_1 = Arc::new(Barrier::new(2)); + let barrier_pending_1_captured = Arc::clone(&barrier_pending_1); + let adapter_captured = Arc::clone(&adapter); + let cache_captured = Arc::clone(&cache); + let handle_1 = tokio::spawn(async move { + cache_captured + .get(1, adapter_captured.get_extra(true)) + .ensure_pending(barrier_pending_1_captured) + .await + }); + + barrier_pending_1.wait().await; + let barrier_pending_2 = Arc::new(Barrier::new(2)); + let barrier_pending_2_captured = Arc::clone(&barrier_pending_2); + let handle_2 = tokio::spawn(async move { + cache + .get(1, adapter.get_extra(false)) + .ensure_pending(barrier_pending_2_captured) + .await + }); + + barrier_pending_2.wait().await; + + // abort first handle + handle_1.abort_and_wait().await; + + loader.mock_next(1, "res_1".to_owned()); + + let n_blocked = loader.unblock_global(); + assert_eq!(n_blocked, 1); + + assert_eq!(handle_2.await.unwrap(), String::from("res_1")); + + assert_eq!(loader.loaded(), vec![(1, true)]); +} + +async fn test_panic_request(adapter: Arc) +where + T: TestAdapter, +{ + let (cache, loader) = setup(adapter.as_ref()); + + loader.block_global(); + + // set up initial panicking request + let barrier_pending_get_panic = Arc::new(Barrier::new(2)); + let barrier_pending_get_panic_captured = Arc::clone(&barrier_pending_get_panic); + let adapter_captured = Arc::clone(&adapter); + let cache_captured = Arc::clone(&cache); + let handle_get_panic = tokio::spawn(async move { + cache_captured + .get(1, adapter_captured.get_extra(true)) + .ensure_pending(barrier_pending_get_panic_captured) + .await + }); + + barrier_pending_get_panic.wait().await; + + // set up other requests + let barrier_pending_others = Arc::new(Barrier::new(4)); + + let barrier_pending_others_captured = 
Arc::clone(&barrier_pending_others); + let adapter_captured = Arc::clone(&adapter); + let cache_captured = Arc::clone(&cache); + let handle_get_while_loading_panic = tokio::spawn(async move { + cache_captured + .get(1, adapter_captured.get_extra(false)) + .ensure_pending(barrier_pending_others_captured) + .await + }); + + let barrier_pending_others_captured = Arc::clone(&barrier_pending_others); + let adapter_captured = Arc::clone(&adapter); + let cache_captured = Arc::clone(&cache); + let handle_peek_while_loading_panic = tokio::spawn(async move { + cache_captured + .peek(1, adapter_captured.peek_extra()) + .ensure_pending(barrier_pending_others_captured) + .await + }); + + let barrier_pending_others_captured = Arc::clone(&barrier_pending_others); + let adapter_captured = Arc::clone(&adapter); + let cache_captured = Arc::clone(&cache); + let handle_get_other_key = tokio::spawn(async move { + cache_captured + .get(2, adapter_captured.get_extra(false)) + .ensure_pending(barrier_pending_others_captured) + .await + }); + + barrier_pending_others.wait().await; + + loader.panic_next(1); + loader.mock_next(1, "res_1".to_owned()); + loader.mock_next(2, "res_2".to_owned()); + + let n_blocked = loader.unblock_global(); + assert_eq!(n_blocked, 2); + + // panic of initial request + handle_get_panic.await.unwrap_err(); + + // requests that use the same loading status also panic + handle_get_while_loading_panic.await.unwrap_err(); + handle_peek_while_loading_panic.await.unwrap_err(); + + // unrelated request should succeed + assert_eq!(handle_get_other_key.await.unwrap(), String::from("res_2")); + + // failing key was tried exactly once (and the other unrelated key as well) + assert_eq!(loader.loaded(), vec![(1, true), (2, false)]); + + // loading after panic just works (no poisoning) + assert_eq!( + cache.get(1, adapter.get_extra(false)).await, + String::from("res_1") + ); + assert_eq!(loader.loaded(), vec![(1, true), (2, false), (1, false)]); +} + +async fn 
test_drop_cancels_loader(adapter: Arc) +where + T: TestAdapter, +{ + let (cache, loader) = setup(adapter.as_ref()); + + loader.block_global(); + + let barrier_pending = Arc::new(Barrier::new(2)); + let barrier_pending_captured = Arc::clone(&barrier_pending); + let handle = tokio::spawn(async move { + cache + .get(1, adapter.get_extra(true)) + .ensure_pending(barrier_pending_captured) + .await + }); + + barrier_pending.wait().await; + + handle.abort_and_wait().await; + + assert_eq!(Arc::strong_count(&loader), 1); +} + +async fn test_set_before_request(adapter: Arc) +where + T: TestAdapter, +{ + let (cache, loader) = setup(adapter.as_ref()); + + loader.block_global(); + + cache.set(1, String::from("foo")).await; + + // blocked loader is not used + let res = tokio::time::timeout( + Duration::from_millis(10), + cache.get(1, adapter.get_extra(false)), + ) + .await + .unwrap(); + assert_eq!(res, String::from("foo")); + assert_eq!(loader.loaded(), Vec::<(u8, bool)>::new()); +} + +async fn test_set_during_request(adapter: Arc) +where + T: TestAdapter, +{ + let (cache, loader) = setup(adapter.as_ref()); + + loader.block_global(); + + let adapter_captured = Arc::clone(&adapter); + let cache_captured = Arc::clone(&cache); + let barrier_pending = Arc::new(Barrier::new(2)); + let barrier_pending_captured = Arc::clone(&barrier_pending); + let handle = tokio::spawn(async move { + cache_captured + .get(1, adapter_captured.get_extra(true)) + .ensure_pending(barrier_pending_captured) + .await + }); + barrier_pending.wait().await; + + cache.set(1, String::from("foo")).await; + + // request succeeds even though the loader is blocked + let res = tokio::time::timeout(Duration::from_millis(10), handle) + .await + .unwrap() + .unwrap(); + assert_eq!(res, String::from("foo")); + assert_eq!(loader.loaded(), vec![(1, true)]); + + // still cached + let res = tokio::time::timeout( + Duration::from_millis(10), + cache.get(1, adapter.get_extra(false)), + ) + .await + .unwrap(); + assert_eq!(res, 
String::from("foo")); + assert_eq!(loader.loaded(), vec![(1, true)]); +} diff --git a/cache_system/src/cancellation_safe_future.rs b/cache_system/src/cancellation_safe_future.rs new file mode 100644 index 0000000..ae45fc3 --- /dev/null +++ b/cache_system/src/cancellation_safe_future.rs @@ -0,0 +1,184 @@ +use std::{ + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use futures::future::BoxFuture; +use parking_lot::Mutex; +use tokio::task::JoinHandle; + +/// Receiver for [`CancellationSafeFuture`] join handles if the future was rescued from cancellation. +/// +/// `T` is the [output type](Future::Output) of the wrapped future. +#[derive(Debug, Default, Clone)] +pub struct CancellationSafeFutureReceiver { + inner: Arc>, +} + +#[derive(Debug, Default)] +struct ReceiverInner { + slot: Mutex>>, +} + +impl Drop for ReceiverInner { + fn drop(&mut self) { + let handle = self.slot.lock(); + if let Some(handle) = handle.as_ref() { + handle.abort(); + } + } +} + +/// Wrapper around a future that cannot be cancelled. +/// +/// When the future is dropped/cancelled, we'll spawn a tokio task to _rescue_ it. +pub struct CancellationSafeFuture +where + F: Future + Send + 'static, + F::Output: Send, +{ + /// Mark if the inner future finished. If not, we must spawn a helper task on drop. + done: bool, + + /// Inner future. + /// + /// Wrapped in an `Option` so we can extract it during drop. Inside that option however we also need a pinned + /// box because once this wrapper is polled, it will be pinned in memory -- even during drop. Now the inner + /// future does not necessarily implement `Unpin`, so we need a heap allocation to pin it in memory even when we + /// move it out of this option. + inner: Option>, + + /// Where to store the join handle on drop. 
+ receiver: CancellationSafeFutureReceiver, +} + +impl Drop for CancellationSafeFuture +where + F: Future + Send + 'static, + F::Output: Send, +{ + fn drop(&mut self) { + if !self.done { + // acquire lock BEFORE checking the Arc + let mut receiver = self.receiver.inner.slot.lock(); + assert!(receiver.is_none()); + + // The Mutex is owned by the Arc and cannot be moved out of it. So after we acquired the lock we can safely + // check if any external party still has access to the receiver state. If not, we assume there is no + // interest in this future at all (e.g. during shutdown) and will NOT spawn it. + if Arc::strong_count(&self.receiver.inner) > 1 { + let inner = self.inner.take().expect("Double-drop?"); + let handle = tokio::task::spawn(inner); + *receiver = Some(handle); + } + } + } +} + +impl CancellationSafeFuture +where + F: Future + Send, + F::Output: Send, +{ + /// Create new future that is protected from cancellation. + /// + /// If [`CancellationSafeFuture`] is cancelled (i.e. dropped) and there is still some external receiver of the state + /// left, than we will drive the payload (`f`) to completion. Otherwise `f` will be cancelled. 
+ pub fn new(fut: F, receiver: CancellationSafeFutureReceiver) -> Self { + Self { + done: false, + inner: Some(Box::pin(fut)), + receiver, + } + } +} + +impl Future for CancellationSafeFuture +where + F: Future + Send, + F::Output: Send, +{ + type Output = F::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + assert!(!self.done, "Polling future that already returned"); + + match self.inner.as_mut().expect("not dropped").as_mut().poll(cx) { + Poll::Ready(res) => { + self.done = true; + Poll::Ready(res) + } + Poll::Pending => Poll::Pending, + } + } +} + +#[cfg(test)] +mod tests { + use std::{ + sync::atomic::{AtomicBool, Ordering}, + time::Duration, + }; + + use tokio::sync::Barrier; + + use super::*; + + #[tokio::test] + async fn test_happy_path() { + let done = Arc::new(AtomicBool::new(false)); + let done_captured = Arc::clone(&done); + + let receiver = Default::default(); + let fut = CancellationSafeFuture::new( + async move { + done_captured.store(true, Ordering::SeqCst); + }, + receiver, + ); + + fut.await; + + assert!(done.load(Ordering::SeqCst)); + } + + #[tokio::test] + async fn test_cancel_future() { + let done = Arc::new(Barrier::new(2)); + let done_captured = Arc::clone(&done); + + let receiver = CancellationSafeFutureReceiver::default(); + let fut = CancellationSafeFuture::new( + async move { + done_captured.wait().await; + }, + receiver.clone(), + ); + + drop(fut); + + tokio::time::timeout(Duration::from_secs(5), done.wait()) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_receiver_gone() { + let done = Arc::new(Barrier::new(2)); + let done_captured = Arc::clone(&done); + + let receiver = Default::default(); + let fut = CancellationSafeFuture::new( + async move { + done_captured.wait().await; + }, + receiver, + ); + + drop(fut); + + assert_eq!(Arc::strong_count(&done), 1); + } +} diff --git a/cache_system/src/lib.rs b/cache_system/src/lib.rs new file mode 100644 index 0000000..68e60ae --- /dev/null +++ 
b/cache_system/src/lib.rs @@ -0,0 +1,29 @@ +//! Flexible and modular cache system. +#![deny(rustdoc::broken_intra_doc_links, rust_2018_idioms)] +#![warn( + missing_copy_implementations, + missing_docs, + clippy::explicit_iter_loop, + // See https://github.com/influxdata/influxdb_iox/pull/1671 + clippy::future_not_send, + clippy::use_self, + clippy::clone_on_ref_ptr, + clippy::todo, + clippy::dbg_macro, + unused_crate_dependencies +)] +#![allow(unreachable_pub)] + +// Workaround for "unused crate" lint false positives. +#[cfg(test)] +use criterion as _; +use workspace_hack as _; + +pub mod addressable_heap; +pub mod backend; +pub mod cache; +mod cancellation_safe_future; +pub mod loader; +pub mod resource_consumption; +#[cfg(test)] +mod test_util; diff --git a/cache_system/src/loader/batch.rs b/cache_system/src/loader/batch.rs new file mode 100644 index 0000000..36ab123 --- /dev/null +++ b/cache_system/src/loader/batch.rs @@ -0,0 +1,501 @@ +//! Batching of loader request. +use std::{ + collections::HashMap, + fmt::Debug, + future::Future, + hash::Hash, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + task::Poll, +}; + +use async_trait::async_trait; +use futures::{ + channel::oneshot::{channel, Sender}, + FutureExt, +}; +use observability_deps::tracing::trace; +use parking_lot::Mutex; + +use crate::cancellation_safe_future::{CancellationSafeFuture, CancellationSafeFutureReceiver}; + +use super::Loader; + +/// Batch [load](Loader::load) requests. +/// +/// Requests against this loader will be [pending](std::task::Poll::Pending) until [flush](BatchLoaderFlusher::flush) is +/// called. To simplify the usage -- esp. in combination with [`Cache::get`] -- use [`BatchLoaderFlusherExt`]. 
+/// +/// +/// [`Cache::get`]: crate::cache::Cache::get +#[derive(Debug)] +pub struct BatchLoader +where + K: Debug + Hash + Send + 'static, + Extra: Debug + Send + 'static, + V: Debug + Send + 'static, + L: Loader, Extra = Vec, V = Vec>, +{ + inner: Arc>, +} + +impl BatchLoader +where + K: Debug + Hash + Send + 'static, + Extra: Debug + Send + 'static, + V: Debug + Send + 'static, + L: Loader, Extra = Vec, V = Vec>, +{ + /// Create new batch loader based on a non-batched, vector-based one. + pub fn new(inner: L) -> Self { + Self { + inner: Arc::new(BatchLoaderInner { + inner, + pending: Default::default(), + job_id_counter: Default::default(), + job_handles: Default::default(), + }), + } + } +} + +/// State of [`BatchLoader`]. +/// +/// This is an extra struct so it can be wrapped into an [`Arc`] and shared with the futures that are spawned into +/// [`CancellationSafeFuture`] +#[derive(Debug)] +struct BatchLoaderInner +where + K: Debug + Hash + Send + 'static, + Extra: Debug + Send + 'static, + V: Debug + Send + 'static, + L: Loader, Extra = Vec, V = Vec>, +{ + inner: L, + pending: Mutex)>>, + job_id_counter: AtomicU64, + job_handles: Mutex>>, +} + +/// Flush interface for [`BatchLoader`]. +/// +/// This is a trait so you can [type-erase](https://en.wikipedia.org/wiki/Type_erasure) it by putting it into an +/// [`Arc`], +/// +/// This trait is object-safe. +#[async_trait] +pub trait BatchLoaderFlusher: Debug + Send + Sync + 'static { + /// Flush all batched requests. 
+ async fn flush(&self); +} + +#[async_trait] +impl BatchLoaderFlusher for Arc { + async fn flush(&self) { + self.as_ref().flush().await; + } +} + +#[async_trait] +impl BatchLoaderFlusher for BatchLoader +where + K: Debug + Hash + Send + 'static, + Extra: Debug + Send + 'static, + V: Debug + Send + 'static, + L: Loader, Extra = Vec, V = Vec>, +{ + async fn flush(&self) { + let pending: Vec<_> = { + let mut pending = self.inner.pending.lock(); + std::mem::take(pending.as_mut()) + }; + + if pending.is_empty() { + return; + } + trace!(n_pending = pending.len(), "flush batch loader",); + + let job_id = self.inner.job_id_counter.fetch_add(1, Ordering::SeqCst); + let handle_recv = CancellationSafeFutureReceiver::default(); + + { + let mut job_handles = self.inner.job_handles.lock(); + job_handles.insert(job_id, handle_recv.clone()); + } + + let inner = Arc::clone(&self.inner); + let fut = CancellationSafeFuture::new( + async move { + let mut keys = Vec::with_capacity(pending.len()); + let mut extras = Vec::with_capacity(pending.len()); + let mut senders = Vec::with_capacity(pending.len()); + + for (k, extra, sender) in pending { + keys.push(k); + extras.push(extra); + senders.push(sender); + } + + let values = inner.inner.load(keys, extras).await; + assert_eq!(values.len(), senders.len()); + + for (value, sender) in values.into_iter().zip(senders) { + sender.send(value).unwrap(); + } + + let mut job_handles = inner.job_handles.lock(); + job_handles.remove(&job_id); + }, + handle_recv, + ); + fut.await; + } +} + +#[async_trait] +impl Loader for BatchLoader +where + K: Debug + Hash + Send + 'static, + Extra: Debug + Send + 'static, + V: Debug + Send + 'static, + L: Loader, Extra = Vec, V = Vec>, +{ + type K = K; + type Extra = Extra; + type V = V; + + async fn load(&self, k: Self::K, extra: Self::Extra) -> Self::V { + let (tx, rx) = channel(); + + { + let mut pending = self.inner.pending.lock(); + pending.push((k, extra, tx)); + } + + rx.await.unwrap() + } +} + +/// 
Extension trait for [`BatchLoaderFlusher`] because the methods on this extension trait are not object safe. +#[async_trait] +pub trait BatchLoaderFlusherExt { + /// Try to poll all given futures and automatically [flush](BatchLoaderFlusher) if any of them end up in a pending state. + /// + /// This guarantees that the order of the results is identical to the order of the futures. + async fn auto_flush(&self, futures: Vec) -> Vec + where + F: Future + Send, + F::Output: Send; +} + +#[async_trait] +impl BatchLoaderFlusherExt for B +where + B: BatchLoaderFlusher, +{ + async fn auto_flush(&self, futures: Vec) -> Vec + where + F: Future + Send, + F::Output: Send, + { + let mut futures = futures + .into_iter() + .map(|f| f.boxed()) + .enumerate() + .collect::>(); + let mut output: Vec> = (0..futures.len()).map(|_| None).collect(); + + while !futures.is_empty() { + let mut pending = Vec::with_capacity(futures.len()); + + for (idx, mut f) in futures.into_iter() { + match futures::poll!(&mut f) { + Poll::Ready(res) => { + output[idx] = Some(res); + } + Poll::Pending => { + pending.push((idx, f)); + } + } + } + + if !pending.is_empty() { + self.flush().await; + + // prevent hot-looping: + // It seems that in some cases the underlying loader is ready but the data is not available via the + // cache driver yet. This is likely due to the signalling system within the cache driver that prevents + // cancelation, but also allows side-loading and at the same time prevents that the same key is loaded + // multiple times. Tokio doesn't know that this method here is basically a wait loop. So we yield back + // to the tokio worker and to allow it to make some progress. Since flush+load take some time anyways, + // this yield here is not overall performance critical. 
+ tokio::task::yield_now().await; + } + + futures = pending; + } + + output + .into_iter() + .map(|o| o.expect("all futures finished")) + .collect() + } +} + +#[cfg(test)] +mod tests { + use tokio::sync::Barrier; + + use crate::{ + cache::{driver::CacheDriver, Cache}, + loader::test_util::TestLoader, + test_util::EnsurePendingExt, + }; + + use super::*; + + type TestLoaderT = Arc, Vec, Vec>>; + + #[tokio::test] + async fn test_flush_empty() { + let (inner, batch) = setup(); + batch.flush().await; + assert_eq!(inner.loaded(), vec![],); + } + + #[tokio::test] + async fn test_flush_manual() { + let (inner, batch) = setup(); + + let pending_barrier_1 = Arc::new(Barrier::new(2)); + let pending_barrier_1_captured = Arc::clone(&pending_barrier_1); + let batch_captured = Arc::clone(&batch); + let handle_1 = tokio::spawn(async move { + batch_captured + .load(1, true) + .ensure_pending(pending_barrier_1_captured) + .await + }); + pending_barrier_1.wait().await; + + let pending_barrier_2 = Arc::new(Barrier::new(2)); + let pending_barrier_2_captured = Arc::clone(&pending_barrier_2); + let batch_captured = Arc::clone(&batch); + let handle_2 = tokio::spawn(async move { + batch_captured + .load(2, false) + .ensure_pending(pending_barrier_2_captured) + .await + }); + pending_barrier_2.wait().await; + + inner.mock_next(vec![1, 2], vec![String::from("foo"), String::from("bar")]); + + batch.flush().await; + assert_eq!(inner.loaded(), vec![(vec![1, 2], vec![true, false])],); + + assert_eq!(handle_1.await.unwrap(), String::from("foo")); + assert_eq!(handle_2.await.unwrap(), String::from("bar")); + } + + /// Simulate the following scenario: + /// + /// 1. load `1`, flush it, inner load starts processing `[1]` + /// 2. load `2`, flush it, inner load starts processing `[2]` + /// 3. inner loader returns result for `[2]`, batch loader returns that result as well + /// 4. 
inner loader returns result for `[1]`, batch loader returns that result as well + #[tokio::test] + async fn test_concurrent_load() { + let (inner, batch) = setup(); + + let load_barrier_1 = inner.block_next(vec![1], vec![String::from("foo")]); + inner.mock_next(vec![2], vec![String::from("bar")]); + + // set up first load + let pending_barrier_1 = Arc::new(Barrier::new(2)); + let pending_barrier_1_captured = Arc::clone(&pending_barrier_1); + let batch_captured = Arc::clone(&batch); + let handle_1 = tokio::spawn(async move { + batch_captured + .load(1, true) + .ensure_pending(pending_barrier_1_captured) + .await + }); + pending_barrier_1.wait().await; + + // flush first load, this is blocked by the load barrier + let pending_barrier_2 = Arc::new(Barrier::new(2)); + let pending_barrier_2_captured = Arc::clone(&pending_barrier_2); + let batch_captured = Arc::clone(&batch); + let handle_2 = tokio::spawn(async move { + batch_captured + .flush() + .ensure_pending(pending_barrier_2_captured) + .await; + }); + pending_barrier_2.wait().await; + + // set up second load + let pending_barrier_3 = Arc::new(Barrier::new(2)); + let pending_barrier_3_captured = Arc::clone(&pending_barrier_3); + let batch_captured = Arc::clone(&batch); + let handle_3 = tokio::spawn(async move { + batch_captured + .load(2, false) + .ensure_pending(pending_barrier_3_captured) + .await + }); + pending_barrier_3.wait().await; + + // flush 2nd load and get result + batch.flush().await; + assert_eq!(handle_3.await.unwrap(), String::from("bar")); + + // flush 1st load and get result + load_barrier_1.wait().await; + handle_2.await.unwrap(); + assert_eq!(handle_1.await.unwrap(), String::from("foo")); + + assert_eq!( + inner.loaded(), + vec![(vec![1], vec![true]), (vec![2], vec![false])], + ); + } + + #[tokio::test] + async fn test_cancel_flush() { + let (inner, batch) = setup(); + + let load_barrier_1 = inner.block_next(vec![1], vec![String::from("foo")]); + + // set up load + let pending_barrier_1 = 
Arc::new(Barrier::new(2)); + let pending_barrier_1_captured = Arc::clone(&pending_barrier_1); + let batch_captured = Arc::clone(&batch); + let handle_1 = tokio::spawn(async move { + batch_captured + .load(1, true) + .ensure_pending(pending_barrier_1_captured) + .await + }); + pending_barrier_1.wait().await; + + // flush load, this is blocked by the load barrier + let pending_barrier_2 = Arc::new(Barrier::new(2)); + let pending_barrier_2_captured = Arc::clone(&pending_barrier_2); + let batch_captured = Arc::clone(&batch); + let handle_2 = tokio::spawn(async move { + batch_captured + .flush() + .ensure_pending(pending_barrier_2_captured) + .await; + }); + pending_barrier_2.wait().await; + + // abort flush + handle_2.abort(); + + // flush load and get result + load_barrier_1.wait().await; + assert_eq!(handle_1.await.unwrap(), String::from("foo")); + + assert_eq!(inner.loaded(), vec![(vec![1], vec![true])],); + } + + #[tokio::test] + async fn test_cancel_load_and_flush() { + let (inner, batch) = setup(); + + let load_barrier_1 = inner.block_next(vec![1], vec![String::from("foo")]); + + // set up load + let pending_barrier_1 = Arc::new(Barrier::new(2)); + let pending_barrier_1_captured = Arc::clone(&pending_barrier_1); + let batch_captured = Arc::clone(&batch); + let handle_1 = tokio::spawn(async move { + batch_captured + .load(1, true) + .ensure_pending(pending_barrier_1_captured) + .await + }); + pending_barrier_1.wait().await; + + // flush load, this is blocked by the load barrier + let pending_barrier_2 = Arc::new(Barrier::new(2)); + let pending_barrier_2_captured = Arc::clone(&pending_barrier_2); + let batch_captured = Arc::clone(&batch); + let handle_2 = tokio::spawn(async move { + batch_captured + .flush() + .ensure_pending(pending_barrier_2_captured) + .await; + }); + pending_barrier_2.wait().await; + + // abort load and flush + handle_1.abort(); + handle_2.abort(); + + // unblock + load_barrier_1.wait().await; + + // load was still driven to completion + 
assert_eq!(inner.loaded(), vec![(vec![1], vec![true])],); + } + + #[tokio::test] + async fn test_auto_flush_with_loader() { + let (inner, batch) = setup(); + + inner.mock_next(vec![1, 2], vec![String::from("foo"), String::from("bar")]); + + assert_eq!( + batch + .auto_flush(vec![batch.load(1, true), batch.load(2, false)]) + .await, + vec![String::from("foo"), String::from("bar")], + ); + + assert_eq!(inner.loaded(), vec![(vec![1, 2], vec![true, false])],); + } + + #[tokio::test] + async fn test_auto_flush_integration_with_cache_driver() { + let (inner, batch) = setup(); + let cache = CacheDriver::new( + Arc::clone(&batch), + HashMap::new(), + &metric::Registry::default(), + "test", + ); + + inner.mock_next(vec![1, 2], vec![String::from("foo"), String::from("bar")]); + inner.mock_next(vec![3], vec![String::from("baz")]); + + assert_eq!( + batch + .auto_flush(vec![cache.get(1, true), cache.get(2, false)]) + .await, + vec![String::from("foo"), String::from("bar")], + ); + assert_eq!( + batch + .auto_flush(vec![cache.get(2, true), cache.get(3, true)]) + .await, + vec![String::from("bar"), String::from("baz")], + ); + + assert_eq!( + inner.loaded(), + vec![(vec![1, 2], vec![true, false]), (vec![3], vec![true])], + ); + } + + fn setup() -> (TestLoaderT, Arc>) { + let inner = TestLoaderT::default(); + let batch = Arc::new(BatchLoader::new(Arc::clone(&inner))); + (inner, batch) + } +} diff --git a/cache_system/src/loader/metrics.rs b/cache_system/src/loader/metrics.rs new file mode 100644 index 0000000..72645b2 --- /dev/null +++ b/cache_system/src/loader/metrics.rs @@ -0,0 +1,247 @@ +//! Metrics for [`Loader`]. + +use std::sync::Arc; + +use async_trait::async_trait; +use iox_time::TimeProvider; +use metric::{DurationHistogram, U64Counter}; +use observability_deps::tracing::warn; +use parking_lot::Mutex; +use pdatastructs::filters::{bloomfilter::BloomFilter, Filter}; + +use super::Loader; + +/// Wraps a [`Loader`] and adds metrics. 
+pub struct MetricsLoader +where + L: Loader, +{ + inner: L, + time_provider: Arc, + metric_calls_new: U64Counter, + metric_calls_probably_reloaded: U64Counter, + metric_duration: DurationHistogram, + seen: Mutex>, +} + +impl MetricsLoader +where + L: Loader, +{ + /// Create new wrapper. + /// + /// # Testing + /// If `testing` is set, the "seen" metrics will NOT be processed correctly because the underlying data structure is + /// too expensive to create many times a second in an un-optimized debug build. + pub fn new( + inner: L, + name: &'static str, + time_provider: Arc, + metric_registry: &metric::Registry, + testing: bool, + ) -> Self { + let metric_calls = metric_registry.register_metric::( + "cache_load_function_calls", + "Count how often a cache loader was called.", + ); + let metric_calls_new = metric_calls.recorder(&[("name", name), ("status", "new")]); + let metric_calls_probably_reloaded = + metric_calls.recorder(&[("name", name), ("status", "probably_reloaded")]); + let metric_duration = metric_registry + .register_metric::( + "cache_load_function_duration", + "Time taken by cache load function calls", + ) + .recorder(&[("name", name)]); + + let seen = if testing { + BloomFilter::with_params(1, 1) + } else { + // Set up bloom filter for "probably reloaded" test: + // + // - input size: we expect 10M elements + // - reliability: probability of false positives should be <= 1% + // - CPU efficiency: number of hash functions should be <= 10 + // - RAM efficiency: size should be <= 15MB + // + // + // A bloom filter was chosen here because of the following properties: + // + // - memory bound: The storage size is bound even when the set of "probably reloaded" entries approaches + // infinite sizes. + // - memory efficiency: We do not need to store the actual keys. + // - infallible: Inserting new data into the filter never fails (in contrast to for example a CuckooFilter or + // QuotientFilter). 
+ // + // The fact that a filter can produce false positives (i.e. it classifies an actual new entry as "probably + // reloaded") is considered to be OK since the metric is more of an estimate and a guide for cache tuning. We + // might want to use a more efficient (i.e. more modern) filter design at one point though. + let seen = BloomFilter::with_properties(10_000_000, 1.0 / 100.0); + const BOUND_HASH_FUNCTIONS: usize = 10; + assert!( + seen.k() <= BOUND_HASH_FUNCTIONS, + "number of hash functions for bloom filter should be <= {} but is {}", + BOUND_HASH_FUNCTIONS, + seen.k(), + ); + const BOUND_SIZE_BYTES: usize = 15_000_000; + let size_bytes = (seen.m() + 7) / 8; + assert!( + size_bytes <= BOUND_SIZE_BYTES, + "size of bloom filter should be <= {BOUND_SIZE_BYTES} bytes but is {size_bytes} bytes", + ); + + seen + }; + + Self { + inner, + time_provider, + metric_calls_new, + metric_calls_probably_reloaded, + metric_duration, + seen: Mutex::new(seen), + } + } +} + +impl std::fmt::Debug for MetricsLoader +where + L: Loader, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MetricsLoader").finish_non_exhaustive() + } +} + +#[async_trait] +impl Loader for MetricsLoader +where + L: Loader, +{ + type K = L::K; + type V = L::V; + type Extra = L::Extra; + + async fn load(&self, k: Self::K, extra: Self::Extra) -> Self::V { + { + let mut seen_guard = self.seen.lock(); + + if seen_guard.insert(&k).expect("bloom filter cannot fail") { + &self.metric_calls_new + } else { + &self.metric_calls_probably_reloaded + } + .inc(1); + } + + let t_start = self.time_provider.now(); + let v = self.inner.load(k, extra).await; + let t_end = self.time_provider.now(); + + match t_end.checked_duration_since(t_start) { + Some(duration) => { + self.metric_duration.record(duration); + } + None => { + warn!("Clock went backwards, not recording loader duration"); + } + } + + v + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use 
iox_time::{MockProvider, Time}; + use metric::{Observation, RawReporter}; + + use crate::loader::FunctionLoader; + + use super::*; + + #[tokio::test] + async fn test_metrics() { + let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_millis(0).unwrap())); + let metric_registry = Arc::new(metric::Registry::new()); + + let time_provider_captured = Arc::clone(&time_provider); + let d = Duration::from_secs(10); + let inner_loader = FunctionLoader::new(move |x: u64, _extra: ()| { + let time_provider_captured = Arc::clone(&time_provider_captured); + async move { + time_provider_captured.inc(d); + x.to_string() + } + }); + + let loader = MetricsLoader::new( + inner_loader, + "my_loader", + time_provider, + &metric_registry, + false, + ); + + let mut reporter = RawReporter::default(); + metric_registry.report(&mut reporter); + for status in ["new", "probably_reloaded"] { + assert_eq!( + reporter + .metric("cache_load_function_calls") + .unwrap() + .observation(&[("name", "my_loader"), ("status", status)]) + .unwrap(), + &Observation::U64Counter(0) + ); + } + if let Observation::DurationHistogram(hist) = reporter + .metric("cache_load_function_duration") + .unwrap() + .observation(&[("name", "my_loader")]) + .unwrap() + { + assert_eq!(hist.sample_count(), 0); + assert_eq!(hist.total, Duration::from_secs(0)); + } else { + panic!("Wrong observation type"); + } + + assert_eq!(loader.load(42, ()).await, String::from("42")); + assert_eq!(loader.load(42, ()).await, String::from("42")); + assert_eq!(loader.load(1337, ()).await, String::from("1337")); + + let mut reporter = RawReporter::default(); + metric_registry.report(&mut reporter); + assert_eq!( + reporter + .metric("cache_load_function_calls") + .unwrap() + .observation(&[("name", "my_loader"), ("status", "new")]) + .unwrap(), + &Observation::U64Counter(2) + ); + assert_eq!( + reporter + .metric("cache_load_function_calls") + .unwrap() + .observation(&[("name", "my_loader"), ("status", "probably_reloaded")]) + 
.unwrap(), + &Observation::U64Counter(1) + ); + if let Observation::DurationHistogram(hist) = reporter + .metric("cache_load_function_duration") + .unwrap() + .observation(&[("name", "my_loader")]) + .unwrap() + { + assert_eq!(hist.sample_count(), 3); + assert_eq!(hist.total, 3 * d); + } else { + panic!("Wrong observation type"); + } + } +} diff --git a/cache_system/src/loader/mod.rs b/cache_system/src/loader/mod.rs new file mode 100644 index 0000000..6c429a7 --- /dev/null +++ b/cache_system/src/loader/mod.rs @@ -0,0 +1,151 @@ +//! How to load new cache entries. +use async_trait::async_trait; +use std::{fmt::Debug, future::Future, hash::Hash, marker::PhantomData, sync::Arc}; + +pub mod batch; +pub mod metrics; + +#[cfg(test)] +pub(crate) mod test_util; + +/// Loader for missing [`Cache`](crate::cache::Cache) entries. +#[async_trait] +pub trait Loader: std::fmt::Debug + Send + Sync + 'static { + /// Cache key. + type K: Debug + Hash + Send + 'static; + + /// Extra data needed when loading a missing entry. Specify `()` if not needed. + type Extra: Debug + Send + 'static; + + /// Cache value. + type V: Debug + Send + 'static; + + /// Load value for given key, using the extra data if needed. + async fn load(&self, k: Self::K, extra: Self::Extra) -> Self::V; +} + +#[async_trait] +impl Loader for Box> +where + K: Debug + Hash + Send + 'static, + V: Debug + Send + 'static, + Extra: Debug + Send + 'static, +{ + type K = K; + type V = V; + type Extra = Extra; + + async fn load(&self, k: Self::K, extra: Self::Extra) -> Self::V { + self.as_ref().load(k, extra).await + } +} + +#[async_trait] +impl Loader for Arc +where + K: Debug + Hash + Send + 'static, + V: Debug + Send + 'static, + Extra: Debug + Send + 'static, + L: Loader, +{ + type K = K; + type V = V; + type Extra = Extra; + + async fn load(&self, k: Self::K, extra: Self::Extra) -> Self::V { + self.as_ref().load(k, extra).await + } +} + +/// Simple-to-use wrapper for async functions to act as a [`Loader`]. 
+/// +/// # Typing +/// Semantically this wrapper has only one degree of freedom: `T`, which is the async loader function. However until +/// [`fn_traits`] are stable, there is no way to extract the parameters and return value from a function via associated +/// types. So we need to add additional type parametes for the special `Fn(...) -> ...` handling. +/// +/// It is likely that `T` will be a closure, e.g.: +/// +/// ``` +/// use cache_system::loader::FunctionLoader; +/// +/// let my_loader = FunctionLoader::new(|k: u8, _extra: ()| async move { +/// format!("{k}") +/// }); +/// ``` +/// +/// There is no way to spell out the exact type of `my_loader` in the above example, because the closure has an +/// anonymous type. If you need the type signature of [`FunctionLoader`], you have to +/// [erase the type](https://en.wikipedia.org/wiki/Type_erasure) by putting the [`FunctionLoader`] it into a [`Box`], +/// e.g.: +/// +/// ``` +/// use cache_system::loader::{Loader, FunctionLoader}; +/// +/// let my_loader = FunctionLoader::new(|k: u8, _extra: ()| async move { +/// format!("{k}") +/// }); +/// let m_loader: Box> = Box::new(my_loader); +/// ``` +/// +/// +/// [`fn_traits`]: https://doc.rust-lang.org/beta/unstable-book/library-features/fn-traits.html +pub struct FunctionLoader +where + T: Fn(K, Extra) -> F + Send + Sync + 'static, + F: Future + Send + 'static, + K: Debug + Send + 'static, + F::Output: Debug + Send + 'static, + Extra: Debug + Send + 'static, +{ + loader: T, + _phantom: PhantomData (F, K, Extra) + Send + Sync + 'static>, +} + +impl FunctionLoader +where + T: Fn(K, Extra) -> F + Send + Sync + 'static, + F: Future + Send + 'static, + K: Debug + Send + 'static, + F::Output: Debug + Send + 'static, + Extra: Debug + Send + 'static, +{ + /// Create loader from function. 
+ pub fn new(loader: T) -> Self { + Self { + loader, + _phantom: PhantomData, + } + } +} + +impl std::fmt::Debug for FunctionLoader +where + T: Fn(K, Extra) -> F + Send + Sync + 'static, + F: Future + Send + 'static, + K: Debug + Send + 'static, + F::Output: Debug + Send + 'static, + Extra: Debug + Send + 'static, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FunctionLoader").finish_non_exhaustive() + } +} + +#[async_trait] +impl Loader for FunctionLoader +where + T: Fn(K, Extra) -> F + Send + Sync + 'static, + F: Future + Send + 'static, + K: Debug + Hash + Send + 'static, + F::Output: Debug + Send + 'static, + Extra: Debug + Send + 'static, +{ + type K = K; + type V = F::Output; + type Extra = Extra; + + async fn load(&self, k: Self::K, extra: Self::Extra) -> Self::V { + (self.loader)(k, extra).await + } +} diff --git a/cache_system/src/loader/test_util.rs b/cache_system/src/loader/test_util.rs new file mode 100644 index 0000000..a35e708 --- /dev/null +++ b/cache_system/src/loader/test_util.rs @@ -0,0 +1,239 @@ +use std::{collections::HashMap, fmt::Debug, hash::Hash, sync::Arc}; + +use async_trait::async_trait; +use parking_lot::Mutex; +use tokio::sync::{Barrier, Notify}; + +use super::Loader; + +#[derive(Debug)] +enum TestLoaderResponse { + Answer { v: V, block: Option> }, + Panic, +} + +/// An easy-to-mock [`Loader`]. +#[derive(Debug, Default)] +pub struct TestLoader +where + K: Clone + Debug + Eq + Hash + Send + 'static, + Extra: Clone + Debug + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + responses: Mutex>>>, + blocked: Mutex>>, + loaded: Mutex>, +} + +impl TestLoader +where + K: Clone + Debug + Eq + Hash + Send + 'static, + Extra: Clone + Debug + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + /// Mock next value for given key-value pair. 
+ pub fn mock_next(&self, k: K, v: V) { + self.mock_inner(k, TestLoaderResponse::Answer { v, block: None }); + } + + /// Block on next load for given key-value pair. + /// + /// Return a barrier that can be used to unblock the load. + #[must_use] + pub fn block_next(&self, k: K, v: V) -> Arc { + let block = Arc::new(Barrier::new(2)); + self.mock_inner( + k, + TestLoaderResponse::Answer { + v, + block: Some(Arc::clone(&block)), + }, + ); + block + } + + /// Panic when loading value for `k`. + /// + /// If this is used together with [`block_global`](Self::block_global), the panic will occur AFTER + /// blocking. + pub fn panic_next(&self, k: K) { + self.mock_inner(k, TestLoaderResponse::Panic); + } + + fn mock_inner(&self, k: K, response: TestLoaderResponse) { + let mut responses = self.responses.lock(); + responses.entry(k).or_default().push(response); + } + + /// Block all [`load`](Self::load) requests until [`unblock`](Self::unblock) is called. + /// + /// If this is used together with [`panic_once`](Self::panic_once), the panic will occur + /// AFTER blocking. + pub fn block_global(&self) { + let mut blocked = self.blocked.lock(); + assert!(blocked.is_none()); + *blocked = Some(Arc::new(Notify::new())); + } + + /// Unblock all requests. + /// + /// Returns number of requests that were blocked. + pub fn unblock_global(&self) -> usize { + let handle = self.blocked.lock().take().unwrap(); + let blocked_count = Arc::strong_count(&handle) - 1; + handle.notify_waiters(); + blocked_count + } + + /// List all keys that were loaded. + /// + /// Contains duplicates if keys were loaded multiple times. + pub fn loaded(&self) -> Vec<(K, Extra)> { + self.loaded.lock().clone() + } +} + +impl Drop for TestLoader +where + K: Clone + Debug + Eq + Hash + Send + 'static, + Extra: Clone + Debug + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + fn drop(&mut self) { + // prevent double-panic (i.e. 
aborts) + if !std::thread::panicking() { + for entries in self.responses.lock().values() { + assert!(entries.is_empty(), "mocked response left"); + } + } + } +} + +#[async_trait] +impl Loader for TestLoader +where + K: Clone + Debug + Eq + Hash + Send + 'static, + Extra: Clone + Debug + Send + 'static, + V: Clone + Debug + Send + 'static, +{ + type K = K; + type Extra = Extra; + type V = V; + + async fn load(&self, k: Self::K, extra: Self::Extra) -> Self::V { + self.loaded.lock().push((k.clone(), extra)); + + // need to capture the cloned notify handle, otherwise the lock guard leaks into the + // generator + let maybe_block = self.blocked.lock().clone(); + if let Some(block) = maybe_block { + block.notified().await; + } + + let response = { + let mut guard = self.responses.lock(); + let entries = guard.get_mut(&k).expect("entry not mocked"); + + assert!(!entries.is_empty(), "no mocked response left"); + + entries.remove(0) + }; + + match response { + TestLoaderResponse::Answer { v, block } => { + if let Some(block) = block { + block.wait().await; + } + + v + } + TestLoaderResponse::Panic => { + panic!("test") + } + } + } +} + +#[cfg(test)] +mod tests { + use futures::FutureExt; + + use super::*; + + #[tokio::test] + #[should_panic(expected = "entry not mocked")] + async fn test_loader_panic_entry_unknown() { + let loader = TestLoader::::default(); + loader.load(1, ()).await; + } + + #[tokio::test] + #[should_panic(expected = "no mocked response left")] + async fn test_loader_panic_no_mocked_reponse_left() { + let loader = TestLoader::default(); + loader.mock_next(1, String::from("foo")); + loader.load(1, ()).await; + loader.load(1, ()).await; + } + + #[test] + #[should_panic(expected = "mocked response left")] + fn test_loader_panic_requests_left() { + let loader = TestLoader::::default(); + loader.mock_next(1, String::from("foo")); + } + + #[test] + #[should_panic(expected = "panic-by-choice")] + fn test_loader_no_double_panic() { + let loader = 
TestLoader::::default(); + loader.mock_next(1, String::from("foo")); + panic!("panic-by-choice"); + } + + #[tokio::test] + async fn test_loader_nonblocking_mock() { + let loader = TestLoader::default(); + + loader.mock_next(1, String::from("foo")); + loader.mock_next(1, String::from("bar")); + loader.mock_next(2, String::from("baz")); + + assert_eq!(loader.load(1, ()).await, String::from("foo")); + assert_eq!(loader.load(2, ()).await, String::from("baz")); + assert_eq!(loader.load(1, ()).await, String::from("bar")); + } + + #[tokio::test] + async fn test_loader_blocking_mock() { + let loader = Arc::new(TestLoader::default()); + + let loader_barrier = loader.block_next(1, String::from("foo")); + loader.mock_next(2, String::from("bar")); + + let is_blocked_barrier = Arc::new(Barrier::new(2)); + + let loader_captured = Arc::clone(&loader); + let is_blocked_barrier_captured = Arc::clone(&is_blocked_barrier); + let handle = tokio::task::spawn(async move { + let mut fut_load = loader_captured.load(1, ()).fuse(); + + futures::select_biased! { + _ = fut_load => { + panic!("should not finish"); + } + _ = is_blocked_barrier_captured.wait().fuse() => {} + } + fut_load.await + }); + + is_blocked_barrier.wait().await; + + // can still load other entries + assert_eq!(loader.load(2, ()).await, String::from("bar")); + + // unblock load + loader_barrier.wait().await; + assert_eq!(handle.await.unwrap(), String::from("foo")); + } +} diff --git a/cache_system/src/resource_consumption.rs b/cache_system/src/resource_consumption.rs new file mode 100644 index 0000000..c2d32ce --- /dev/null +++ b/cache_system/src/resource_consumption.rs @@ -0,0 +1,195 @@ +//! Reasoning about resource consumption of cached data. +use std::{ + fmt::Debug, + marker::PhantomData, + ops::{Add, Sub}, +}; + +/// Strongly-typed resource consumption. +/// +/// Can be used to represent in-RAM memory as well as on-disc memory. 
+pub trait Resource: + Add + + Copy + + Debug + + Into + + Ord + + PartialOrd + + Send + + Sync + + Sub + + 'static +{ + /// Create resource consumption of zero. + fn zero() -> Self; + + /// Unit name. + /// + /// This must be a single lowercase word. + fn unit() -> &'static str; +} + +/// An estimator of [`Resource`] consumption for a given key-value pair. +pub trait ResourceEstimator: Debug + Send + Sync + 'static { + /// Cache key. + type K; + + /// Cached value. + type V; + + /// Size that can be estimated. + type S: Resource; + + /// Estimate size of given key-value pair. + fn consumption(&self, k: &Self::K, v: &Self::V) -> Self::S; +} + +/// A simple function-based [`ResourceEstimator]. +/// +/// # Typing +/// Semantically this wrapper has only one degree of freedom: `F`, which is the estimator function. However until +/// [`fn_traits`] are stable, there is no way to extract the parameters and return value from a function via associated +/// types. So we need to add additional type parametes for the special `Fn(...) -> ...` handling. +/// +/// It is likely that `F` will be a closure, e.g.: +/// +/// ``` +/// use cache_system::resource_consumption::{ +/// FunctionEstimator, +/// test_util::TestSize, +/// }; +/// +/// let my_estimator = FunctionEstimator::new(|_k: &u8, v: &String| -> TestSize { +/// TestSize(std::mem::size_of::<(u8, String)>() + v.capacity()) +/// }); +/// ``` +/// +/// There is no way to spell out the exact type of `my_estimator` in the above example, because the closure has an +/// anonymous type. 
If you need the type signature of [`FunctionEstimator`], you have to +/// [erase the type](https://en.wikipedia.org/wiki/Type_erasure) by putting the [`FunctionEstimator`] it into a [`Box`], +/// e.g.: +/// +/// ``` +/// use cache_system::resource_consumption::{ +/// FunctionEstimator, +/// ResourceEstimator, +/// test_util::TestSize, +/// }; +/// +/// let my_estimator = FunctionEstimator::new(|_k: &u8, v: &String| -> TestSize { +/// TestSize(std::mem::size_of::<(u8, String)>() + v.capacity()) +/// }); +/// let my_estimator: Box> = Box::new(my_estimator); +/// ``` +/// +/// +/// [`fn_traits`]: https://doc.rust-lang.org/beta/unstable-book/library-features/fn-traits.html +pub struct FunctionEstimator +where + F: Fn(&K, &V) -> S + Send + Sync + 'static, + K: 'static, + V: 'static, + S: Resource, +{ + estimator: F, + _phantom: PhantomData (K, V, S) + Send + Sync + 'static>, +} + +impl FunctionEstimator +where + F: Fn(&K, &V) -> S + Send + Sync + 'static, + K: 'static, + V: 'static, + S: Resource, +{ + /// Create new resource estimator from given function. + pub fn new(f: F) -> Self { + Self { + estimator: f, + _phantom: PhantomData, + } + } +} + +impl std::fmt::Debug for FunctionEstimator +where + F: Fn(&K, &V) -> S + Send + Sync + 'static, + K: 'static, + V: 'static, + S: Resource, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FunctionEstimator").finish_non_exhaustive() + } +} + +impl ResourceEstimator for FunctionEstimator +where + F: Fn(&K, &V) -> S + Send + Sync + 'static, + K: 'static, + V: 'static, + S: Resource, +{ + type K = K; + type V = V; + type S = S; + + fn consumption(&self, k: &Self::K, v: &Self::V) -> Self::S { + (self.estimator)(k, v) + } +} + +pub mod test_util { + //! Helpers to test resource consumption-based algorithms. + use super::*; + + /// Simple resource type for testing. 
+ #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] + pub struct TestSize(pub usize); + + impl Resource for TestSize { + fn zero() -> Self { + Self(0) + } + + fn unit() -> &'static str { + "bytes" + } + } + + impl From for u64 { + fn from(s: TestSize) -> Self { + s.0 as Self + } + } + + impl Add for TestSize { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self(self.0.checked_add(rhs.0).expect("overflow")) + } + } + + impl Sub for TestSize { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self(self.0.checked_sub(rhs.0).expect("underflow")) + } + } +} + +#[cfg(test)] +mod tests { + use crate::resource_consumption::test_util::TestSize; + + use super::*; + + #[test] + fn test_function_estimator() { + let estimator = + FunctionEstimator::new(|k: &u8, v: &u16| TestSize((*k as usize) * 10 + (*v as usize))); + assert_eq!(estimator.consumption(&3, &2), TestSize(32)); + } +} diff --git a/cache_system/src/test_util.rs b/cache_system/src/test_util.rs new file mode 100644 index 0000000..959cc68 --- /dev/null +++ b/cache_system/src/test_util.rs @@ -0,0 +1,62 @@ +use std::{future::Future, sync::Arc, time::Duration}; + +use async_trait::async_trait; +use futures::FutureExt; +use tokio::{sync::Barrier, task::JoinHandle}; + +#[async_trait] +pub trait EnsurePendingExt { + type Out; + + /// Ensure that the future is pending. In the pending case, try to pass the given barrier. Afterwards await the future again. + /// + /// This is helpful to ensure a future is in a pending state before continuing with the test setup. + async fn ensure_pending(self, barrier: Arc) -> Self::Out; +} + +#[async_trait] +impl EnsurePendingExt for F +where + F: Future + Send + Unpin, +{ + type Out = F::Output; + + async fn ensure_pending(self, barrier: Arc) -> Self::Out { + let mut fut = self.fuse(); + futures::select_biased! 
{ + _ = fut => panic!("fut should be pending"), + _ = barrier.wait().fuse() => (), + } + + fut.await + } +} + +#[async_trait] +pub trait AbortAndWaitExt { + /// Abort handle and wait for completion. + /// + /// Note that this is NOT just a "wait with timeout or panic". This extension is specific to [`JoinHandle`] and will: + /// + /// 1. Call [`JoinHandle::abort`]. + /// 2. Await the [`JoinHandle`] with a timeout (or panic if the timeout is reached). + /// 3. Check that the handle returned a [`JoinError`] that signals that the tracked task was indeed cancelled and + /// didn't exit otherwise (either by finishing or by panicking). + async fn abort_and_wait(self); +} + +#[async_trait] +impl AbortAndWaitExt for JoinHandle +where + T: std::fmt::Debug + Send, +{ + async fn abort_and_wait(mut self) { + self.abort(); + + let join_err = tokio::time::timeout(Duration::from_secs(1), self) + .await + .expect("no timeout") + .expect_err("handle was aborted and therefore MUST fail"); + assert!(join_err.is_cancelled()); + } +} diff --git a/catalog_cache/Cargo.toml b/catalog_cache/Cargo.toml new file mode 100644 index 0000000..cdb79c5 --- /dev/null +++ b/catalog_cache/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "catalog_cache" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[lints] +workspace = true + +[dependencies] +bytes = "1.5" +dashmap = "5.5" +futures = "0.3" +hyper = "0.14" +url = "2.5" +reqwest = { version = "0.11", default-features = false } +snafu = "0.8" +tokio = { version = "1.35", default-features = false, features = ["macros", "rt"] } +tokio-util = "0.7" +workspace-hack = { version = "0.1", path = "../workspace-hack" } + +[dev-dependencies] diff --git a/catalog_cache/src/api/client.rs b/catalog_cache/src/api/client.rs new file mode 100644 index 0000000..94e9bf9 --- /dev/null +++ b/catalog_cache/src/api/client.rs @@ -0,0 +1,176 @@ +//! 
Client for the cache HTTP API + +use crate::api::list::{ListDecoder, ListEntry, MAX_VALUE_SIZE}; +use crate::api::{RequestPath, GENERATION}; +use crate::{CacheKey, CacheValue}; +use bytes::{Buf, Bytes}; +use futures::prelude::*; +use futures::stream::BoxStream; +use reqwest::{Client, Response, StatusCode, Url}; +use snafu::{OptionExt, ResultExt, Snafu}; +use std::time::Duration; + +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +pub enum Error { + #[snafu(display("Creating client: {source}"))] + Client { source: reqwest::Error }, + + #[snafu(display("Put Reqwest error: {source}"))] + Put { source: reqwest::Error }, + + #[snafu(display("Get Reqwest error: {source}"))] + Get { source: reqwest::Error }, + + #[snafu(display("List Reqwest error: {source}"))] + List { source: reqwest::Error }, + + #[snafu(display("Health Reqwest error: {source}"))] + Health { source: reqwest::Error }, + + #[snafu(display("Missing generation header"))] + MissingGeneration, + + #[snafu(display("Invalid generation value"))] + InvalidGeneration, + + #[snafu(display("Error decoding list stream: {source}"), context(false))] + ListStream { source: crate::api::list::Error }, +} + +/// Result type for [`CatalogCacheClient`] +pub type Result = std::result::Result; + +/// The type returned by [`CatalogCacheClient::list`] +pub type ListStream = BoxStream<'static, Result>; + +const RESOURCE_REQUEST_TIMEOUT: Duration = Duration::from_secs(1); + +/// We use a longer timeout for list request as they may transfer a non-trivial amount of data +const LIST_REQUEST_TIMEOUT: Duration = Duration::from_secs(20); + +/// A client for accessing a remote catalog cache +#[derive(Debug)] +pub struct CatalogCacheClient { + client: Client, + endpoint: Url, +} + +impl CatalogCacheClient { + /// Create a new [`CatalogCacheClient`] with the given remote endpoint + pub fn try_new(endpoint: Url) -> Result { + let client = Client::builder() + .connect_timeout(Duration::from_secs(2)) + .build() + .context(ClientSnafu)?; + + 
Ok(Self { endpoint, client }) + } + + /// Retrieve the given value from the remote cache, if present + pub async fn get(&self, key: CacheKey) -> Result> { + let url = format!("{}{}", self.endpoint, RequestPath::Resource(key)); + let timeout = RESOURCE_REQUEST_TIMEOUT; + let req = self.client.get(url).timeout(timeout); + let resp = req.send().await.context(GetSnafu)?; + + if resp.status() == StatusCode::NOT_FOUND { + return Ok(None); + } + let resp = resp.error_for_status().context(GetSnafu)?; + + let generation = resp + .headers() + .get(&GENERATION) + .context(MissingGenerationSnafu)?; + + let generation = generation + .to_str() + .ok() + .and_then(|v| v.parse().ok()) + .context(InvalidGenerationSnafu)?; + + let data = resp.bytes().await.context(GetSnafu)?; + + Ok(Some(CacheValue::new(data, generation))) + } + + /// Upsert the given key-value pair to the remote cache + /// + /// Returns false if the value had a generation less than or equal to + /// an existing value + pub async fn put(&self, key: CacheKey, value: &CacheValue) -> Result { + let url = format!("{}{}", self.endpoint, RequestPath::Resource(key)); + + let response = self + .client + .put(url) + .timeout(RESOURCE_REQUEST_TIMEOUT) + .header(&GENERATION, value.generation) + .body(value.data.clone()) + .send() + .await + .context(PutSnafu)? + .error_for_status() + .context(PutSnafu)?; + + Ok(matches!(response.status(), StatusCode::OK)) + } + + /// List the contents of the remote cache + /// + /// Values larger than `max_value_size` will not be returned inline, with only the key + /// and generation returned instead. 
Defaults to [`MAX_VALUE_SIZE`] + pub fn list(&self, max_value_size: Option) -> ListStream { + let size = max_value_size.unwrap_or(MAX_VALUE_SIZE); + let url = format!("{}{}?size={size}", self.endpoint, RequestPath::List); + let fut = self.client.get(url).timeout(LIST_REQUEST_TIMEOUT).send(); + + futures::stream::once(fut.map_err(|source| Error::List { source })) + .and_then(move |response| futures::future::ready(list_stream(response, size))) + .try_flatten() + .boxed() + } +} + +struct ListStreamState { + response: Response, + current: Bytes, + decoder: ListDecoder, +} + +impl ListStreamState { + fn new(response: Response, max_value_size: usize) -> Self { + Self { + response, + current: Default::default(), + decoder: ListDecoder::new().with_max_value_size(max_value_size), + } + } +} + +fn list_stream( + response: Response, + max_value_size: usize, +) -> Result>> { + let resp = response.error_for_status().context(ListSnafu)?; + let state = ListStreamState::new(resp, max_value_size); + Ok(stream::try_unfold(state, |mut state| async move { + loop { + if state.current.is_empty() { + match state.response.chunk().await.context(ListSnafu)? { + Some(new) => state.current = new, + None => break, + } + } + + let to_read = state.current.len(); + let read = state.decoder.decode(&state.current)?; + state.current.advance(read); + if read != to_read { + break; + } + } + Ok(state.decoder.flush()?.map(|entry| (entry, state))) + })) +} diff --git a/catalog_cache/src/api/list.rs b/catalog_cache/src/api/list.rs new file mode 100644 index 0000000..155f794 --- /dev/null +++ b/catalog_cache/src/api/list.rs @@ -0,0 +1,467 @@ +//! The encoding mechanism for list streams +//! +//! This is capable of streaming both keys and values, this saves round-trips when hydrating +//! 
a cache from a remote, and avoids creating a flood of HTTP GET requests + +use bytes::Bytes; +use snafu::{ensure, Snafu}; + +use crate::{CacheKey, CacheValue}; + +/// Error type for list streams +#[derive(Debug, Snafu)] +#[allow(missing_copy_implementations, missing_docs)] +pub enum Error { + #[snafu(display("Unexpected EOF whilst decoding list stream"))] + UnexpectedEOF, + + #[snafu(display("List value of {size} bytes too large"))] + ValueTooLarge { size: usize }, +} + +/// Result type for list streams +pub type Result = std::result::Result; + +/// The size at which to flush [`Bytes`] from [`ListEncoder`] +pub const FLUSH_SIZE: usize = 1024 * 1024; // Flush in 1MB blocks + +/// The maximum value size to send in a [`ListEntry`] +/// +/// This primarily exists as a self-protection limit to prevent large or corrupted streams +/// from swamping the client, but also mitigates Head-Of-Line blocking resulting from +/// large cache values +pub const MAX_VALUE_SIZE: usize = 1024 * 10; + +/// Encodes [`ListEntry`] as an iterator of [`Bytes`] +/// +/// Each [`ListEntry`] is encoded as a `ListHeader`, followed by the value data +#[derive(Debug)] +pub struct ListEncoder { + /// The current offset into entries + offset: usize, + /// The ListEntry to encode + entries: Vec, + /// The flush size, made configurable for testing + flush_size: usize, + /// The maximum value size to write + max_value_size: usize, +} + +impl ListEncoder { + /// Create a new [`ListEncoder`] from the provided [`ListEntry`] + pub fn new(entries: Vec) -> Self { + Self { + entries, + offset: 0, + flush_size: FLUSH_SIZE, + max_value_size: MAX_VALUE_SIZE, + } + } + + /// Override the maximum value size to write + pub fn with_max_value_size(mut self, size: usize) -> Self { + self.max_value_size = size; + self + } +} + +impl Iterator for ListEncoder { + type Item = Bytes; + + fn next(&mut self) -> Option { + if self.offset == self.entries.len() { + return None; + } + + let mut cap = 0; + let mut end_offset = 
self.offset; + while end_offset < self.entries.len() && cap < self.flush_size { + match &self.entries[end_offset].data { + Some(d) if d.len() <= self.max_value_size => cap += ListHeader::SIZE + d.len(), + _ => cap += ListHeader::SIZE, + }; + end_offset += 1; + } + + let mut buf = Vec::with_capacity(cap); + for entry in self.entries.iter().take(end_offset).skip(self.offset) { + match &entry.data { + Some(d) if d.len() <= self.max_value_size => { + buf.extend_from_slice(&entry.header(false).encode()); + buf.extend_from_slice(d) + } + _ => buf.extend_from_slice(&entry.header(true).encode()), + } + } + self.offset = end_offset; + Some(buf.into()) + } +} + +#[allow(non_snake_case)] +mod Flags { + /// The value is not included in this response + /// + /// [`ListEncoder`](super::ListEncoder) only sends inline values for values smaller than a + /// configured threshold + pub(crate) const HEAD: u8 = 1; +} + +/// The header encoded in a list stream +#[derive(Debug)] +struct ListHeader { + /// The size of the value + size: u32, + /// Reserved for future usage + reserved: u16, + /// A bitmask of [`Flags`] + flags: u8, + /// The variant of [`CacheKey`] + variant: u8, + /// The generation of this value + generation: u64, + /// The key contents of [`CacheKey`] + key: u128, +} + +impl ListHeader { + /// The encoded size of [`ListHeader`] + const SIZE: usize = 32; + + /// Encodes [`ListHeader`] to an array + fn encode(&self) -> [u8; Self::SIZE] { + let mut out = [0; Self::SIZE]; + out[..4].copy_from_slice(&self.size.to_le_bytes()); + out[4..6].copy_from_slice(&self.reserved.to_le_bytes()); + out[6] = self.flags; + out[7] = self.variant; + out[8..16].copy_from_slice(&self.generation.to_le_bytes()); + out[16..32].copy_from_slice(&self.key.to_le_bytes()); + out + } + + /// Decodes [`ListHeader`] from an array + fn decode(buf: [u8; Self::SIZE]) -> Self { + Self { + size: u32::from_le_bytes(buf[..4].try_into().unwrap()), + reserved: u16::from_le_bytes(buf[4..6].try_into().unwrap()), + 
flags: buf[6], + variant: buf[7], + generation: u64::from_le_bytes(buf[8..16].try_into().unwrap()), + key: u128::from_le_bytes(buf[16..32].try_into().unwrap()), + } + } +} + +/// The state for [`ListDecoder`] +#[derive(Debug)] +enum DecoderState { + /// Decoding a header, contains the decoded data and the current offset + Header([u8; ListHeader::SIZE], usize), + /// Decoding value data for the given [`ListHeader`] + Body(ListHeader, Vec), +} + +impl Default for DecoderState { + fn default() -> Self { + Self::Header([0; ListHeader::SIZE], 0) + } +} + +/// Decodes [`ListEntry`] from a stream of bytes +#[derive(Debug)] +pub struct ListDecoder { + state: DecoderState, + max_size: usize, +} + +impl Default for ListDecoder { + fn default() -> Self { + Self { + state: DecoderState::default(), + max_size: MAX_VALUE_SIZE, + } + } +} + +impl ListDecoder { + /// Create a new [`ListDecoder`] + pub fn new() -> Self { + Self::default() + } + + /// Overrides the maximum value to deserialize + /// + /// Values larger than this will result in an error + /// Defaults to [`MAX_VALUE_SIZE`] + pub fn with_max_value_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } + + /// Decode an entry from `buf`, returning the number of bytes consumed + /// + /// This is meant to be used in combination with [`Self::flush`] + pub fn decode(&mut self, mut buf: &[u8]) -> Result { + let initial = buf.len(); + while !buf.is_empty() { + match &mut self.state { + DecoderState::Header(header, offset) => { + let to_read = buf.len().min(ListHeader::SIZE - *offset); + header[*offset..*offset + to_read].copy_from_slice(&buf[..to_read]); + *offset += to_read; + buf = &buf[to_read..]; + + if *offset == ListHeader::SIZE { + let header = ListHeader::decode(*header); + let size = header.size as _; + ensure!(size <= self.max_size, ValueTooLargeSnafu { size }); + self.state = DecoderState::Body(header, Vec::with_capacity(size)) + } + } + DecoderState::Body(header, value) => { + let to_read = 
buf.len().min(header.size as usize - value.len()); + if to_read == 0 { + break; + } + value.extend_from_slice(&buf[..to_read]); + buf = &buf[to_read..]; + } + } + } + Ok(initial - buf.len()) + } + + /// Flush the contents of this [`ListDecoder`] + /// + /// Returns `Ok(Some(entry))` if a record is fully decoded + /// Returns `Ok(None)` if no in-progress record + /// Otherwise returns an error + pub fn flush(&mut self) -> Result> { + match std::mem::take(&mut self.state) { + DecoderState::Body(header, value) if value.len() == header.size as usize => { + Ok(Some(ListEntry { + variant: header.variant, + key: header.key, + generation: header.generation, + data: ((header.flags & Flags::HEAD) == 0).then(|| value.into()), + })) + } + DecoderState::Header(_, 0) => Ok(None), + _ => Err(Error::UnexpectedEOF), + } + } +} + +/// A key value pair encoded as part of a list +/// +/// Unlike [`CacheKey`] and [`CacheValue`] this allows: +/// +/// * Non-fatal handling of unknown key variants +/// * The option to not include the value data, e.g. 
if too large +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct ListEntry { + variant: u8, + generation: u64, + key: u128, + data: Option, +} + +impl ListEntry { + /// Create a new [`ListEntry`] from the provided key and value + pub fn new(key: CacheKey, value: CacheValue) -> Self { + let (variant, key) = match key { + CacheKey::Namespace(v) => (b'n', v as _), + CacheKey::Table(v) => (b't', v as _), + CacheKey::Partition(v) => (b'p', v as _), + }; + + Self { + key, + variant, + generation: value.generation, + data: Some(value.data), + } + } + + /// The key if it matches a known variant of [`CacheKey`] + /// + /// Returns `None` otherwise + pub fn key(&self) -> Option { + match self.variant { + b't' => Some(CacheKey::Table(self.key as _)), + b'n' => Some(CacheKey::Namespace(self.key as _)), + b'p' => Some(CacheKey::Partition(self.key as _)), + _ => None, + } + } + + /// The generation of this entry + pub fn generation(&self) -> u64 { + self.generation + } + + /// Returns the value data if present + pub fn value(&self) -> Option<&Bytes> { + self.data.as_ref() + } + + /// Returns the [`ListHeader`] for a given [`ListEntry`] + fn header(&self, head: bool) -> ListHeader { + let generation = self.generation; + let (flags, size) = match (head, &self.data) { + (false, Some(data)) => (0, data.len() as u32), + _ => (Flags::HEAD, 0), + }; + + ListHeader { + size, + flags, + variant: self.variant, + key: self.key, + generation, + reserved: 0, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Buf; + use std::io::BufRead; + + fn decode_entries(mut r: R) -> Result> { + let mut decoder = ListDecoder::default(); + let iter = std::iter::from_fn(move || { + loop { + let buf = r.fill_buf().unwrap(); + if buf.is_empty() { + break; + } + let to_read = buf.len(); + let read = decoder.decode(buf).unwrap(); + r.consume(read); + if read != to_read { + break; + } + } + decoder.flush().transpose() + }); + iter.collect() + } + + #[test] + fn test_roundtrip() { + let 
expected = vec![ + ListEntry::new(CacheKey::Namespace(2), CacheValue::new("test".into(), 32)), + ListEntry::new(CacheKey::Namespace(6), CacheValue::new("3".into(), 4)), + ListEntry { + variant: 0, + key: u128::MAX, + generation: u64::MAX, + data: Some("unknown".into()), + }, + ListEntry::new(CacheKey::Table(6), CacheValue::new("3".into(), 23)), + ListEntry { + variant: b'p', + key: 45, + generation: 23, + data: None, + }, + ListEntry::new( + CacheKey::Partition(3), + CacheValue::new("bananas".into(), 23), + ), + ]; + + let encoded: Vec<_> = ListEncoder::new(expected.clone()).collect(); + assert_eq!(encoded.len(), 1); // Expect entries to be encoded in single flush + + for buf_size in [3, 5, 12] { + let reader = std::io::BufReader::with_capacity(buf_size, encoded[0].clone().reader()); + let entries = decode_entries(reader).unwrap(); + assert_eq!(entries, expected); + + // Invalid key should not be fatal + assert_eq!(entries[2].key(), None); + // Head response should not be fatal + assert_eq!(entries[4].value(), None); + } + } + + #[test] + fn test_empty() { + let data: Vec<_> = ListEncoder::new(vec![]).collect(); + assert_eq!(data.len(), 0); + + let entries = decode_entries(std::io::Cursor::new([0_u8; 0])).unwrap(); + assert_eq!(entries.len(), 0); + } + + #[test] + fn test_flush_size() { + let data = Bytes::from(vec![0; 128]); + let entries = (0..1024) + .map(|x| ListEntry::new(CacheKey::Namespace(x), CacheValue::new(data.clone(), 0))) + .collect(); + + let mut encoder = ListEncoder::new(entries); + encoder.flush_size = 1024; // Lower limit for test + + let mut remaining = 1024; + for block in encoder { + let expected = remaining.min(7); + assert_eq!(block.len(), (data.len() + ListHeader::SIZE) * expected); + let decoded = decode_entries(block.reader()).unwrap(); + assert_eq!(decoded.len(), expected); + remaining -= expected; + } + } + + #[test] + fn test_size_limit() { + let entries = vec![ + ListEntry::new( + CacheKey::Namespace(0), + CacheValue::new(vec![0; 
128].into(), 0), + ), + ListEntry::new( + CacheKey::Namespace(1), + CacheValue::new(vec![0; 129].into(), 0), + ), + ListEntry::new( + CacheKey::Namespace(2), + CacheValue::new(vec![0; 128].into(), 0), + ), + ]; + + let mut encoder = ListEncoder::new(entries); + encoder.max_value_size = 128; // Artificially lower limit for test + + let encoded: Vec<_> = encoder.collect(); + assert_eq!(encoded.len(), 1); + + let decoded = decode_entries(encoded[0].clone().reader()).unwrap(); + assert_eq!(decoded[0].value().unwrap().len(), 128); + assert_eq!(decoded[1].value(), None); // Should omit value that is too large + assert_eq!(decoded[2].value().unwrap().len(), 128); + + let mut decoder = ListDecoder::new(); + decoder.max_size = 12; + let err = decoder.decode(&encoded[0]).unwrap_err().to_string(); + assert_eq!(err, "List value of 128 bytes too large"); + + let mut decoder = ListDecoder::new(); + decoder.max_size = 128; + + let consumed = decoder.decode(&encoded[0]).unwrap(); + let r = decoder.flush().unwrap().unwrap(); + assert_eq!(r.value().unwrap().len(), 128); + + // Next record skipped by encoder as too large + decoder.decode(&encoded[0][consumed..]).unwrap(); + let r = decoder.flush().unwrap().unwrap(); + assert_eq!(r.value(), None); + } +} diff --git a/catalog_cache/src/api/mod.rs b/catalog_cache/src/api/mod.rs new file mode 100644 index 0000000..66d4042 --- /dev/null +++ b/catalog_cache/src/api/mod.rs @@ -0,0 +1,159 @@ +//! 
The remote API for the catalog cache + +use crate::CacheKey; +use hyper::http::HeaderName; + +pub mod client; + +pub mod quorum; + +pub mod server; + +pub mod list; + +/// The header used to encode the generation in a get response +static GENERATION: HeaderName = HeaderName::from_static("x-influx-generation"); + +/// Defines the mapping to HTTP paths for given request types +#[derive(Debug, Eq, PartialEq)] +enum RequestPath { + /// A request addressing a resource identified by [`CacheKey`] + Resource(CacheKey), + /// A list request + List, +} + +impl RequestPath { + fn parse(s: &str) -> Option { + let s = s.strip_prefix('/').unwrap_or(s); + if s == "v1/" { + return Some(Self::List); + } + + let (prefix, value) = s.rsplit_once('/')?; + let value = u64::from_str_radix(value, 16).ok()?; + match prefix { + "v1/n" => Some(Self::Resource(CacheKey::Namespace(value as i64))), + "v1/t" => Some(Self::Resource(CacheKey::Table(value as i64))), + "v1/p" => Some(Self::Resource(CacheKey::Partition(value as i64))), + _ => None, + } + } +} + +impl std::fmt::Display for RequestPath { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::List => write!(f, "v1/"), + Self::Resource(CacheKey::Namespace(v)) => write!(f, "v1/n/{v:016x}"), + Self::Resource(CacheKey::Table(v)) => write!(f, "v1/t/{v:016x}"), + Self::Resource(CacheKey::Partition(v)) => write!(f, "v1/p/{v:016x}"), + } + } +} + +#[cfg(test)] +mod tests { + use crate::api::list::ListEntry; + use crate::api::server::test_util::TestCacheServer; + use crate::api::RequestPath; + use crate::{CacheKey, CacheValue}; + use futures::TryStreamExt; + use std::collections::HashSet; + + #[test] + fn test_request_path() { + let paths = [ + RequestPath::List, + RequestPath::Resource(CacheKey::Partition(12)), + RequestPath::Resource(CacheKey::Partition(i64::MAX)), + RequestPath::Resource(CacheKey::Partition(i64::MIN)), + RequestPath::Resource(CacheKey::Namespace(12)), + 
RequestPath::Resource(CacheKey::Namespace(i64::MAX)), + RequestPath::Resource(CacheKey::Namespace(i64::MIN)), + RequestPath::Resource(CacheKey::Table(12)), + RequestPath::Resource(CacheKey::Table(i64::MAX)), + RequestPath::Resource(CacheKey::Table(i64::MIN)), + ]; + + let mut set = HashSet::with_capacity(paths.len()); + for path in paths { + let s = path.to_string(); + let back = RequestPath::parse(&s).unwrap(); + assert_eq!(back, path); + assert!(set.insert(s), "should be unique"); + } + } + + #[tokio::test] + async fn test_basic() { + let serve = TestCacheServer::bind_ephemeral(); + let client = serve.client(); + + let key = CacheKey::Partition(1); + + let v1 = CacheValue::new("1".into(), 2); + assert!(client.put(key, &v1).await.unwrap()); + + let returned = client.get(key).await.unwrap().unwrap(); + assert_eq!(v1, returned); + + // Duplicate upsert ignored + assert!(!client.put(key, &v1).await.unwrap()); + + // Stale upsert ignored + let v2 = CacheValue::new("2".into(), 1); + assert!(!client.put(key, &v2).await.unwrap()); + + let returned = client.get(key).await.unwrap().unwrap(); + assert_eq!(v1, returned); + + let v3 = CacheValue::new("3".into(), 3); + assert!(client.put(key, &v3).await.unwrap()); + + let returned = client.get(key).await.unwrap().unwrap(); + assert_eq!(v3, returned); + + let key2 = CacheKey::Partition(5); + assert!(client.put(key2, &v1).await.unwrap()); + + let mut result = client.list(None).try_collect::>().await.unwrap(); + result.sort_unstable_by_key(|entry| entry.key()); + + let expected = vec![ListEntry::new(key, v3), ListEntry::new(key2, v1)]; + assert_eq!(result, expected); + + serve.shutdown().await; + } + + #[tokio::test] + async fn test_list_size() { + let serve = TestCacheServer::bind_ephemeral(); + let client = serve.client(); + + let v1 = CacheValue::new("123".into(), 2); + client.put(CacheKey::Table(1), &v1).await.unwrap(); + + let v2 = CacheValue::new("13".into(), 2); + client.put(CacheKey::Table(2), &v2).await.unwrap(); + + let 
v3 = CacheValue::new("1".into(), 2); + client.put(CacheKey::Table(3), &v3).await.unwrap(); + + let mut res = client.list(Some(2)).try_collect::>().await.unwrap(); + res.sort_unstable_by_key(|x| x.key()); + + assert_eq!(res.len(), 3); + + assert_eq!(res[0].value(), None); + assert_eq!(res[1].value(), Some(&v2.data)); + assert_eq!(res[2].value(), Some(&v3.data)); + + let mut res = client.list(Some(3)).try_collect::>().await.unwrap(); + res.sort_unstable_by_key(|x| x.key()); + + assert_eq!(res[0].value(), Some(&v1.data)); + assert_eq!(res[1].value(), Some(&v2.data)); + assert_eq!(res[2].value(), Some(&v3.data)); + } +} diff --git a/catalog_cache/src/api/quorum.rs b/catalog_cache/src/api/quorum.rs new file mode 100644 index 0000000..17c4edf --- /dev/null +++ b/catalog_cache/src/api/quorum.rs @@ -0,0 +1,459 @@ +//! Client for performing quorum catalog reads/writes + +use crate::api::client::{CatalogCacheClient, Error as ClientError}; +use crate::local::CatalogCache; +use crate::{CacheKey, CacheValue}; +use futures::channel::oneshot; +use futures::future::{select, Either}; +use futures::{pin_mut, StreamExt}; +use snafu::{ResultExt, Snafu}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::task::JoinError; +use tokio_util::sync::CancellationToken; + +/// Error for [`QuorumCatalogCache`] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("Failed to communicate with any remote replica: {source}"))] + NoRemote { source: ClientError }, + + #[snafu(display("Write task was aborted"))] + Cancelled, + + #[snafu(display("Join Error: {source}"))] + Join { source: JoinError }, + + #[snafu(display("Failed to establish a read quorum: {generations:?}"))] + Quorum { + generations: [Result, ClientError>; 3], + }, + + #[snafu(display("Failed to list replica: {source}"))] + List { source: ClientError }, + + #[snafu(display("Local cache error: {source}"), context(false))] + Local { source: crate::local::Error }, +} + +/// Result for 
[`QuorumCatalogCache`] +pub type Result = std::result::Result; + +/// Performs quorum reads and writes across a local [`CatalogCache`] and two [`CatalogCacheClient`] +#[derive(Debug)] +pub struct QuorumCatalogCache { + local: Arc, + replicas: Arc<[CatalogCacheClient; 2]>, + shutdown: CancellationToken, +} + +impl Drop for QuorumCatalogCache { + fn drop(&mut self) { + self.shutdown.cancel() + } +} + +impl QuorumCatalogCache { + /// Create a new [`QuorumCatalogCache`] + pub fn new(local: Arc, replicas: Arc<[CatalogCacheClient; 2]>) -> Self { + Self { + local, + replicas, + shutdown: CancellationToken::new(), + } + } + + /// Retrieve the given value from the remote cache + /// + /// Returns `None` if value is not present in a quorum of replicas + /// Returns [`Error::Quorum`] if cannot establish a read quorum + pub async fn get(&self, key: CacheKey) -> Result> { + let local = self.local.get(key); + + let fut1 = self.replicas[0].get(key); + let fut2 = self.replicas[1].get(key); + pin_mut!(fut1); + pin_mut!(fut2); + + match select(fut1, fut2).await { + Either::Left((result, fut)) | Either::Right((result, fut)) => match (local, result) { + (None, Ok(None)) => Ok(None), + (Some(l), Ok(Some(r))) if l.generation <= r.generation => { + // preempt write from remote to local that arrives late + if l.generation < r.generation { + self.local.insert(key, r.clone())?; + } + Ok(Some(r)) + } + (local, r1) => { + // r1 either failed or did not return anything + let r2 = fut.await; + match (local, r1, r2) { + (None, _, Ok(None)) | (_, Ok(None), Ok(None)) => Ok(None), + (Some(l), _, Ok(Some(r))) if l.generation <= r.generation => { + // preempt write from remote to local that arrives late + if l.generation < r.generation { + self.local.insert(key, r.clone())?; + } + Ok(Some(r)) + } + (local, Ok(Some(l)), Ok(Some(r))) if l.generation == r.generation => { + if local.map(|x| x.generation < l.generation).unwrap_or(true) { + self.local.insert(key, l.clone())?; + } + Ok(Some(l)) + } + (l, 
r1, r2) => Err(Error::Quorum { + generations: [ + Ok(l.map(|x| x.generation)), + r1.map(|x| x.map(|x| x.generation)), + r2.map(|x| x.map(|x| x.generation)), + ], + }), + } + } + }, + } + } + + /// Upsert the given key-value pair + /// + /// Returns Ok if able to replicate the write to a quorum + pub async fn put(&self, key: CacheKey, value: CacheValue) -> Result<()> { + self.local.insert(key, value.clone())?; + + let replicas = Arc::clone(&self.replicas); + let (sender, receiver) = oneshot::channel(); + + let fut = async move { + let fut1 = replicas[0].put(key, &value); + let fut2 = replicas[1].put(key, &value); + pin_mut!(fut1); + pin_mut!(fut2); + + match select(fut1, fut2).await { + Either::Left((r, fut)) | Either::Right((r, fut)) => { + let _ = sender.send(r); + fut.await + } + } + }; + + // We spawn a tokio task so that we can potentially continue to replicate + // to the second replica asynchronously once we receive an ok response + let cancel = self.shutdown.child_token(); + let handle = tokio::spawn(async move { + let cancelled = cancel.cancelled(); + pin_mut!(fut); + pin_mut!(cancelled); + match select(cancelled, fut).await { + Either::Left(_) => Err(Error::Cancelled), + Either::Right((Ok(_), _)) => Ok(()), + Either::Right((Err(source), _)) => Err(Error::NoRemote { source }), + } + }); + + match receiver.await { + Ok(Ok(_)) => Ok(()), + _ => match handle.await { + Ok(r) => r, + Err(source) => Err(Error::Join { source }), + }, + } + } + + /// Warm the local cache by performing quorum reads from the other two replicas + /// + /// This method should be called after this server has been participating in the write quorum + /// for a period of time, e.g. 1 minute. This avoids an issue where a quorum cannot be + /// established for in-progress writes. 
+ pub async fn warm(&self) -> Result<()> { + // List doesn't return keys in any particular order + // + // We therefore build a hashmap with the keys from one replica and compare + // this against those returned by the other + // + // We don't need to consult the local `CatalogCache`, as we only need to insert + // if a read quorum can be established between the replicas and isn't present locally + let mut generations = HashMap::with_capacity(128); + let mut list = self.replicas[0].list(Some(0)); + while let Some(entry) = list.next().await.transpose().context(ListSnafu)? { + if let Some(k) = entry.key() { + generations.insert(k, entry.generation()); + } + } + + let mut list = self.replicas[1].list(None); + while let Some(entry) = list.next().await.transpose().context(ListSnafu)? { + if let Some(k) = entry.key() { + match (generations.get(&k), entry.value()) { + (Some(generation), Some(v)) if *generation == entry.generation() => { + let value = CacheValue::new(v.clone(), *generation); + // In the case that local already has the given version + // this will be a no-op + self.local.insert(k, value)?; + } + _ => {} + } + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::api::server::test_util::TestCacheServer; + use std::future::Future; + use std::task::Context; + use std::time::Duration; + + #[tokio::test] + async fn test_basic() { + let local = Arc::new(CatalogCache::default()); + let r1 = TestCacheServer::bind_ephemeral(); + let r2 = TestCacheServer::bind_ephemeral(); + + let replicas = Arc::new([r1.client(), r2.client()]); + let quorum = QuorumCatalogCache::new(Arc::clone(&local), Arc::clone(&replicas)); + + assert_eq!(quorum.get(CacheKey::Table(1)).await.unwrap(), None); + + let k1 = CacheKey::Table(1); + let k2 = CacheKey::Table(2); + let k3 = CacheKey::Table(3); + + let v1 = CacheValue::new("foo".into(), 2); + quorum.put(k1, v1.clone()).await.unwrap(); + quorum.put(k2, v1.clone()).await.unwrap(); + + let r = 
quorum.get(k2).await.unwrap().unwrap(); + assert_eq!(r, v1); + + // New value + let v2 = CacheValue::new("foo".into(), 4); + quorum.put(k2, v2.clone()).await.unwrap(); + + let r = quorum.get(k1).await.unwrap().unwrap(); + assert_eq!(r, v1); + + let r = quorum.get(k2).await.unwrap().unwrap(); + assert_eq!(r, v2); + + // Can remove value from one replica and still get quorum + r2.cache().delete(k2).unwrap(); + let r = quorum.get(k2).await.unwrap().unwrap(); + assert_eq!(r, v2); + + // Loss of two copies results in not found + r1.cache().delete(k2).unwrap(); + let r = quorum.get(k2).await.unwrap(); + assert_eq!(r, None); + + // Simulate stale value in r1 + r1.cache().insert(k2, v1.clone()).unwrap(); + let err = quorum.get(k2).await.unwrap_err(); + assert!(matches!(err, Error::Quorum { .. }), "{err}"); + + // If quorum has stale value follows quorum + r2.cache().delete(k2); + r2.cache().insert(k2, v1.clone()).unwrap(); + let r = quorum.get(k2).await.unwrap().unwrap(); + assert_eq!(r, v1); + + // Simulate loss of replica 2 + r2.shutdown().await; + + // Can still establish a write quorum + quorum.put(k3, v1.clone()).await.unwrap(); + + // Can read newly inserted value + let r = quorum.get(k3).await.unwrap().unwrap(); + assert_eq!(r, v1); + + // Can still read from quorum of k1 + let r = quorum.get(k1).await.unwrap().unwrap(); + assert_eq!(r, v1); + + // Cannot get quorum as lost single node and local disagrees with replica 1 + let err = quorum.get(k2).await.unwrap_err(); + assert!(matches!(err, Error::Quorum { .. }), "{err}"); + + // Can establish quorum following write + quorum.put(k2, v2.clone()).await.unwrap(); + let r = quorum.get(k2).await.unwrap().unwrap(); + assert_eq!(r, v2); + + // Still cannot establish quorum + r1.cache().delete(k2); + let err = quorum.get(k2).await.unwrap_err(); + assert!(matches!(err, Error::Quorum { .. 
}), "{err}"); + + // k2 is now no longer present anywhere, can establish quorum + local.delete(k2); + let r = quorum.get(k2).await.unwrap(); + assert_eq!(r, None); + + // Simulate loss of replica 1 (in addition to replica 2) + r1.shutdown().await; + + // Can no longer get quorum for anything + let err = quorum.get(k1).await.unwrap_err(); + assert!(matches!(err, Error::Quorum { .. }), "{err}"); + } + + #[tokio::test] + async fn test_read_through() { + let local = Arc::new(CatalogCache::default()); + let r1 = TestCacheServer::bind_ephemeral(); + let r2 = TestCacheServer::bind_ephemeral(); + + let replicas = Arc::new([r1.client(), r2.client()]); + let quorum = QuorumCatalogCache::new(Arc::clone(&local), Arc::clone(&replicas)); + + let key = CacheKey::Table(1); + let v0 = CacheValue::new("v0".into(), 0); + + r1.cache().insert(key, v0.clone()).unwrap(); + r2.cache().insert(key, v0.clone()).unwrap(); + + let result = quorum.get(key).await.unwrap().unwrap(); + assert_eq!(result, v0); + + // Should have read-through to local + assert_eq!(local.get(key).unwrap(), v0); + + let v1 = CacheValue::new("v1".into(), 1); + let v2 = CacheValue::new("v2".into(), 2); + + r1.cache().insert(key, v1.clone()).unwrap(); + r2.cache().insert(key, v2.clone()).unwrap(); + + // A quorum request will get either v1 or v2 depending on which it contacts first + let result = quorum.get(key).await.unwrap().unwrap(); + assert!(result == v1 || result == v2, "{result:?}"); + + // Should read-through + assert_eq!(local.get(key).unwrap(), result); + + // Update r1 with version 2 + r1.cache().insert(key, v2.clone()).unwrap(); + + let result = quorum.get(key).await.unwrap().unwrap(); + assert_eq!(result, v2); + + // Should read-through + assert_eq!(local.get(key).unwrap(), v2); + + let v3 = CacheValue::new("v3".into(), 3); + local.insert(key, v3.clone()).unwrap(); + + // Should establish quorum for v2 even though local is v3 + let result = quorum.get(key).await.unwrap().unwrap(); + assert_eq!(result, v2); + 
+ // Should not read-through + assert_eq!(local.get(key).unwrap(), v3); + + let v4 = CacheValue::new("v4".into(), 4); + let v5 = CacheValue::new("v5".into(), 5); + + local.insert(key, v5.clone()).unwrap(); + r1.cache().insert(key, v4.clone()).unwrap(); + + // Should fail as cannot establish quorum of three different versions of `[5, 4, 2]` + // and has latest version locally + let err = quorum.get(key).await.unwrap_err(); + assert!(matches!(err, Error::Quorum { .. }), "{err}"); + assert_eq!(local.get(key).unwrap(), v5); + + let v6 = CacheValue::new("v6".into(), 6); + r1.cache().insert(key, v6.clone()).unwrap(); + + // Should succeed as r1 has newer version than local + let result = quorum.get(key).await.unwrap().unwrap(); + assert_eq!(result, v6); + + // Should read-through + assert_eq!(local.get(key).unwrap(), v6); + } + + #[tokio::test] + async fn test_warm() { + let local = Arc::new(CatalogCache::default()); + let r1 = TestCacheServer::bind_ephemeral(); + let r2 = TestCacheServer::bind_ephemeral(); + + let replicas = Arc::new([r1.client(), r2.client()]); + let quorum = QuorumCatalogCache::new(local, Arc::clone(&replicas)); + + let k1 = CacheKey::Table(1); + let v1 = CacheValue::new("v1".into(), 1); + quorum.put(k1, v1.clone()).await.unwrap(); + + let k2 = CacheKey::Table(2); + let v2 = CacheValue::new("v2".into(), 1); + quorum.put(k2, v2.clone()).await.unwrap(); + + // Simulate local restart + let local = Arc::new(CatalogCache::default()); + let quorum = QuorumCatalogCache::new(Arc::clone(&local), Arc::clone(&replicas)); + + assert_eq!(local.list().count(), 0); + + quorum.warm().await.unwrap(); + + // Should populate both entries + let mut entries: Vec<_> = local.list().collect(); + entries.sort_unstable_by_key(|(k, _)| *k); + assert_eq!(entries, vec![(k1, v1.clone()), (k2, v2.clone())]); + + // Simulate local restart + let local = Arc::new(CatalogCache::default()); + let quorum = QuorumCatalogCache::new(Arc::clone(&local), Arc::clone(&replicas)); + + // 
Simulate in-progress write + let v3 = CacheValue::new("v3".into(), 2); + assert!(r1.cache().insert(k2, v3.clone()).unwrap()); + + // Cannot establish quorum for k1 so should skip over + quorum.warm().await.unwrap(); + let entries: Vec<_> = local.list().collect(); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0], (k1, v1.clone())); + + // If r2 updated warming should pick up new quorum + assert!(r2.cache().insert(k2, v3.clone()).unwrap()); + quorum.warm().await.unwrap(); + let mut entries: Vec<_> = local.list().collect(); + entries.sort_unstable_by_key(|(k, _)| *k); + assert_eq!(entries, vec![(k1, v1), (k2, v3)]); + + // Test cancellation safety + let k3 = CacheKey::Table(3); + let fut = quorum.put(k3, v2.clone()); + { + // `fut` is dropped (cancelled) on exit from this code block + pin_mut!(fut); + + let noop_waker = futures::task::noop_waker(); + let mut cx = Context::from_waker(&noop_waker); + assert!(fut.poll(&mut cx).is_pending()); + } + + // Write should still propagate asynchronously + let mut attempts = 0; + loop { + tokio::time::sleep(Duration::from_millis(1)).await; + match quorum.get(k3).await { + Ok(Some(_)) => break, + _ => { + assert!(attempts < 100); + attempts += 1; + } + } + } + } +} diff --git a/catalog_cache/src/api/server.rs b/catalog_cache/src/api/server.rs new file mode 100644 index 0000000..b29d841 --- /dev/null +++ b/catalog_cache/src/api/server.rs @@ -0,0 +1,300 @@ +//! 
Server for the cache HTTP API + +use crate::api::list::{ListEncoder, ListEntry}; +use crate::api::{RequestPath, GENERATION}; +use crate::local::CatalogCache; +use crate::CacheValue; +use futures::ready; +use hyper::body::HttpBody; +use hyper::header::ToStrError; +use hyper::http::request::Parts; +use hyper::service::Service; +use hyper::{Body, Method, Request, Response, StatusCode}; +use snafu::{OptionExt, ResultExt, Snafu}; +use std::convert::Infallible; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +enum Error { + #[snafu(display("Http error: {source}"), context(false))] + Http { source: hyper::http::Error }, + + #[snafu(display("Hyper error: {source}"), context(false))] + Hyper { source: hyper::Error }, + + #[snafu(display("Local cache error: {source}"), context(false))] + Local { source: crate::local::Error }, + + #[snafu(display("Non UTF-8 Header: {source}"))] + BadHeader { source: ToStrError }, + + #[snafu(display("Request missing generation header"))] + MissingGeneration, + + #[snafu(display("Invalid generation header: {source}"))] + InvalidGeneration { source: std::num::ParseIntError }, + + #[snafu(display("List query missing size"))] + MissingSize, + + #[snafu(display("List query invalid size: {source}"))] + InvalidSize { source: std::num::ParseIntError }, +} + +impl Error { + /// Convert an error into a [`Response`] + fn response(self) -> Response { + let mut response = Response::new(Body::from(self.to_string())); + *response.status_mut() = match &self { + Self::Http { .. } | Self::Hyper { .. } | Self::Local { .. } => { + StatusCode::INTERNAL_SERVER_ERROR + } + Self::InvalidGeneration { .. } + | Self::MissingGeneration + | Self::InvalidSize { .. } + | Self::MissingSize + | Self::BadHeader { .. 
} => StatusCode::BAD_REQUEST, + }; + response + } +} + +/// A [`Service`] that wraps a [`CatalogCache`] +#[derive(Debug, Clone)] +pub struct CatalogCacheService(Arc); + +/// Shared state for [`CatalogCacheService`] +#[derive(Debug)] +struct ServiceState { + cache: Arc, +} + +impl Service> for CatalogCacheService { + type Response = Response; + + type Error = Infallible; + type Future = CatalogRequestFuture; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let (parts, body) = req.into_parts(); + CatalogRequestFuture { + parts, + body, + buffer: vec![], + state: Arc::clone(&self.0), + } + } +} + +/// The future for [`CatalogCacheService`] +#[derive(Debug)] +pub struct CatalogRequestFuture { + /// The request body + body: Body, + /// The request parts + parts: Parts, + /// The in-progress body + /// + /// We use Vec not Bytes to ensure the cache isn't storing slices of large allocations + buffer: Vec, + /// The cache to service requests + state: Arc, +} + +impl Future for CatalogRequestFuture { + type Output = Result, Infallible>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let r = loop { + match ready!(Pin::new(&mut self.body).poll_data(cx)) { + Some(Ok(b)) => self.buffer.extend_from_slice(&b), + Some(Err(e)) => break Err(e.into()), + None => break Ok(()), + } + }; + Poll::Ready(Ok(match r.and_then(|_| self.call()) { + Ok(resp) => resp, + Err(e) => e.response(), + })) + } +} + +impl CatalogRequestFuture { + fn call(&mut self) -> Result, Error> { + let body = std::mem::take(&mut self.buffer); + + let status = match RequestPath::parse(self.parts.uri.path()) { + Some(RequestPath::List) => match self.parts.method { + Method::GET => { + let query = self.parts.uri.query().context(MissingSizeSnafu)?; + let mut parts = url::form_urlencoded::parse(query.as_bytes()); + let (_, size) = parts.find(|(k, _)| k == "size").context(MissingSizeSnafu)?; + let 
size = size.parse().context(InvalidSizeSnafu)?; + + let iter = self.state.cache.list(); + let entries = iter.map(|(k, v)| ListEntry::new(k, v)).collect(); + let encoder = ListEncoder::new(entries).with_max_value_size(size); + + let stream = futures::stream::iter(encoder.map(Ok::<_, Error>)); + let response = Response::builder().body(Body::wrap_stream(stream))?; + return Ok(response); + } + _ => StatusCode::METHOD_NOT_ALLOWED, + }, + Some(RequestPath::Resource(key)) => match self.parts.method { + Method::GET => match self.state.cache.get(key) { + Some(value) => { + let response = Response::builder() + .header(&GENERATION, value.generation) + .body(value.data.into())?; + return Ok(response); + } + None => StatusCode::NOT_FOUND, + }, + Method::PUT => { + let headers = &self.parts.headers; + let generation = headers.get(&GENERATION).context(MissingGenerationSnafu)?; + let generation = generation.to_str().context(BadHeaderSnafu)?; + let generation = generation.parse().context(InvalidGenerationSnafu)?; + let value = CacheValue::new(body.into(), generation); + + match self.state.cache.insert(key, value)? { + true => StatusCode::OK, + false => StatusCode::NOT_MODIFIED, + } + } + Method::DELETE => { + self.state.cache.delete(key); + StatusCode::OK + } + _ => StatusCode::METHOD_NOT_ALLOWED, + }, + None => StatusCode::NOT_FOUND, + }; + + let mut response = Response::new(Body::empty()); + *response.status_mut() = status; + Ok(response) + } +} + +/// Runs a [`CatalogCacheService`] in a background task +/// +/// Will abort the background task on drop +#[derive(Debug)] +pub struct CatalogCacheServer { + state: Arc, +} + +impl CatalogCacheServer { + /// Create a new [`CatalogCacheServer`]. + /// + /// Note that the HTTP interface needs to be wired up in some higher-level structure. Use [`service`](Self::service) + /// for that. + pub fn new(cache: Arc) -> Self { + let state = Arc::new(ServiceState { cache }); + + Self { state } + } + + /// Returns HTTP service. 
+ pub fn service(&self) -> CatalogCacheService { + CatalogCacheService(Arc::clone(&self.state)) + } + + /// Returns a reference to the [`CatalogCache`] of this server + pub fn cache(&self) -> &Arc { + &self.state.cache + } +} + +/// Test utilities. +pub mod test_util { + use std::{net::SocketAddr, ops::Deref}; + + use hyper::{service::make_service_fn, Server}; + use tokio::task::JoinHandle; + use tokio_util::sync::CancellationToken; + + use crate::api::client::CatalogCacheClient; + + use super::*; + + /// Test runner for a [`CatalogCacheServer`]. + #[derive(Debug)] + pub struct TestCacheServer { + addr: SocketAddr, + server: CatalogCacheServer, + shutdown: CancellationToken, + handle: Option>, + } + + impl TestCacheServer { + /// Create a new [`TestCacheServer`] bound to an ephemeral port + pub fn bind_ephemeral() -> Self { + Self::bind(&SocketAddr::from(([127, 0, 0, 1], 0))) + } + + /// Create a new [`CatalogCacheServer`] bound to the provided [`SocketAddr`] + pub fn bind(addr: &SocketAddr) -> Self { + let server = CatalogCacheServer::new(Arc::new(CatalogCache::default())); + let service = server.service(); + let make_service = make_service_fn(move |_conn| { + futures::future::ready(Ok::<_, Infallible>(service.clone())) + }); + + let hyper_server = Server::bind(addr).serve(make_service); + let addr = hyper_server.local_addr(); + + let shutdown = CancellationToken::new(); + let signal = shutdown.clone().cancelled_owned(); + let graceful = hyper_server.with_graceful_shutdown(signal); + let handle = Some(tokio::spawn(async move { graceful.await.unwrap() })); + + Self { + addr, + server, + shutdown, + handle, + } + } + + /// Returns a [`CatalogCacheClient`] for communicating with this server + pub fn client(&self) -> CatalogCacheClient { + let addr = format!("http://{}", self.addr); + CatalogCacheClient::try_new(addr.parse().unwrap()).unwrap() + } + + /// Triggers and waits for graceful shutdown + pub async fn shutdown(mut self) { + self.shutdown.cancel(); + if let 
Some(x) = self.handle.take() { + x.await.unwrap() + } + } + } + + impl Deref for TestCacheServer { + type Target = CatalogCacheServer; + + fn deref(&self) -> &Self::Target { + &self.server + } + } + + impl Drop for TestCacheServer { + fn drop(&mut self) { + if let Some(x) = &self.handle { + x.abort() + } + } + } +} diff --git a/catalog_cache/src/lib.rs b/catalog_cache/src/lib.rs new file mode 100644 index 0000000..0370448 --- /dev/null +++ b/catalog_cache/src/lib.rs @@ -0,0 +1,143 @@ +//! Consistent cache system used by the catalog service +//! +//! # Design +//! +//! The catalog service needs to be able to service queries without needing to communicate +//! with its underlying backing store. This serves the dual purpose of reducing load on this +//! backing store, and also returning results in a more timely manner. +//! +//! This caching must be transparent to the users of the catalog service, and therefore must not +//! introduce eventually consistent behaviour, or other consistency effects. +//! +//! As such this crate provides a strongly-consistent, distributed key-value cache. +//! +//! In order to keep things simple, this only provides a mapping from [`CacheKey`] to opaque +//! binary payloads, with no support for structured payloads. +//! +//! This avoids: +//! +//! * Complex replicated state machines +//! * Forward compatibility challenges where newer data can't roundtrip through older servers +//! * Simple to introspect, debug and reason about +//! * Predictable and easily quantifiable memory usage +//! +//! However, it does have the following implications: +//! +//! * Care must be taken to ensure that parsing of the cached payloads does not become a bottleneck +//! * Large values (> 1MB) should be avoided, as updates will resend the entire value +//! +//! ## Components +//! +//! This crate is broken into multiple parts +//! +//! * [`CatalogCache`] provides a local key value store +//! 
* [`CatalogCacheService`] exposes this [`CatalogCache`] over an HTTP API +//! * [`CatalogCacheClient`] communicates with a remote [`CatalogCacheService`] +//! * [`QuorumCatalogCache`] combines the above into a strongly-consistent distributed cache +//! +//! [`CatalogCache`]: local::CatalogCache +//! [`CatalogCacheClient`]: api::client::CatalogCacheClient +//! [`CatalogCacheService`]: api::server::CatalogCacheService +//! [`QuorumCatalogCache`]: api::quorum::QuorumCatalogCache +//! +#![deny(rustdoc::broken_intra_doc_links, rust_2018_idioms)] +#![warn( + missing_copy_implementations, + missing_docs, + clippy::explicit_iter_loop, + // See https://github.com/influxdata/influxdb_iox/pull/1671 + clippy::future_not_send, + clippy::use_self, + clippy::clone_on_ref_ptr, + clippy::todo, + clippy::dbg_macro, + unused_crate_dependencies +)] + +// Workaround for "unused crate" lint false positives. +use workspace_hack as _; + +use bytes::Bytes; +use std::sync::atomic::AtomicBool; + +pub mod api; +pub mod local; + +/// The types of catalog cache key +#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] +pub enum CacheKey { + /// A catalog namespace + Namespace(i64), + /// A catalog table + Table(i64), + /// A catalog partition + Partition(i64), +} + +impl CacheKey { + /// Variant as string. + /// + /// This can be used for logging and metrics. + pub fn variant(&self) -> &'static str { + match self { + Self::Namespace(_) => "namespace", + Self::Table(_) => "table", + Self::Partition(_) => "partition", + } + } + + /// Untyped ID. 
+ pub fn id(&self) -> i64 { + match self { + Self::Namespace(id) => *id, + Self::Table(id) => *id, + Self::Partition(id) => *id, + } + } +} + +/// A value stored in [`CatalogCache`](local::CatalogCache) +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct CacheValue { + /// The data stored for this cache + data: Bytes, + /// The generation of this cache data + generation: u64, +} + +impl CacheValue { + /// Create a new [`CacheValue`] with the provided `data` and `generation` + pub fn new(data: Bytes, generation: u64) -> Self { + Self { data, generation } + } + + /// The data stored for this cache + pub fn data(&self) -> &Bytes { + &self.data + } + + /// The generation of this cache data + pub fn generation(&self) -> u64 { + self.generation + } +} + +/// Combines a [`CacheValue`] with an [`AtomicBool`] for the purposes of NRU-eviction +#[derive(Debug)] +struct CacheEntry { + /// The value of this cache entry + value: CacheValue, + /// An atomic flag that is set to `true` by `CatalogCache::get` and + /// cleared by `CatalogCache::evict_unused` + used: AtomicBool, +} + +impl From for CacheEntry { + fn from(value: CacheValue) -> Self { + Self { + value, + // Values start used to prevent racing with `evict_unused` + used: AtomicBool::new(true), + } + } +} diff --git a/catalog_cache/src/local/limit.rs b/catalog_cache/src/local/limit.rs new file mode 100644 index 0000000..6c38fee --- /dev/null +++ b/catalog_cache/src/local/limit.rs @@ -0,0 +1,82 @@ +//! 
A memory limiter

use super::{Error, Result};
use std::sync::atomic::{AtomicUsize, Ordering};

/// Byte-budget accounting for the cache: lock-free reserve/free of sizes
/// against a fixed upper bound.
#[derive(Debug)]
pub(crate) struct MemoryLimiter {
    // Bytes currently reserved; only mutated via `reserve`/`free`
    current: AtomicUsize,
    // Immutable hard upper bound in bytes
    limit: usize,
}

impl MemoryLimiter {
    /// Create a new [`MemoryLimiter`] limited to `limit` bytes
    pub(crate) fn new(limit: usize) -> Self {
        Self {
            current: AtomicUsize::new(0),
            limit,
        }
    }

    /// Reserve `size` bytes, returning an error if this would exceed the limit
    pub(crate) fn reserve(&self, size: usize) -> Result<()> {
        let limit = self.limit;
        // `max` is the largest `current` for which `current + size` still fits;
        // a request larger than the whole budget fails up-front with `TooLarge`
        let max = limit
            .checked_sub(size)
            .ok_or(Error::TooLarge { size, limit })?;

        // We can use relaxed ordering as not relying on this to
        // synchronise memory accesses beyond itself
        self.current
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                // This cannot overflow as current + size <= limit
                (current <= max).then_some(current + size)
            })
            .map_err(|current| Error::OutOfMemory {
                size,
                current,
                limit,
            })?;
        Ok(())
    }

    /// Free `size` bytes
    ///
    /// Callers must only free what they previously reserved; no underflow check here
    pub(crate) fn free(&self, size: usize) {
        self.current.fetch_sub(size, Ordering::Relaxed);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_limiter() {
        let limiter = MemoryLimiter::new(100);

        limiter.reserve(20).unwrap();
        limiter.reserve(70).unwrap();

        let err = limiter.reserve(20).unwrap_err().to_string();
        assert_eq!(err, "Cannot reserve additional 20 bytes for cache containing 90 bytes as would exceed limit of 100 bytes");

        limiter.reserve(10).unwrap();
        limiter.reserve(0).unwrap();

        let err = limiter.reserve(1).unwrap_err().to_string();
        assert_eq!(err, "Cannot reserve additional 1 bytes for cache containing 100 bytes as would exceed limit of 100 bytes");

        limiter.free(10);
        limiter.reserve(10).unwrap();

        limiter.free(100);

        // Can add single value taking entire range
        limiter.reserve(100).unwrap();
        limiter.free(100);

        // Protected against overflow
        let err =
limiter.reserve(usize::MAX).unwrap_err(); + assert!(matches!(err, Error::TooLarge { .. }), "{err}"); + } +} diff --git a/catalog_cache/src/local/mod.rs b/catalog_cache/src/local/mod.rs new file mode 100644 index 0000000..373dd62 --- /dev/null +++ b/catalog_cache/src/local/mod.rs @@ -0,0 +1,355 @@ +//! A local in-memory cache + +mod limit; + +use crate::local::limit::MemoryLimiter; +use crate::{CacheEntry, CacheKey, CacheValue}; +use dashmap::mapref::entry::Entry; +use dashmap::DashMap; +use snafu::Snafu; +use std::sync::atomic::Ordering; +use std::sync::Arc; + +/// Error for [`CatalogCache`] +#[derive(Debug, Snafu)] +#[allow(missing_docs, missing_copy_implementations)] +pub enum Error { + #[snafu(display("Cannot reserve additional {size} bytes for cache containing {current} bytes as would exceed limit of {limit} bytes"))] + OutOfMemory { + size: usize, + current: usize, + limit: usize, + }, + + #[snafu(display("Cannot reserve additional {size} bytes for cache as request exceeds total memory limit of {limit} bytes"))] + TooLarge { size: usize, limit: usize }, +} + +/// Result for [`CatalogCache`] +pub type Result = std::result::Result; + +/// A trait for observing updated to [`CatalogCache`] +/// +/// This can be used for injecting metrics, maintaining secondary indices or otherwise +/// +/// Note: members are invoked under locks in [`CatalogCache`] and should therefore +/// be short-running and not call back into [`CatalogCache`] +pub trait CatalogCacheObserver: std::fmt::Debug + Send + Sync { + /// Called before a value is potentially inserted into [`CatalogCache`] + /// + /// This is called regardless of it [`CatalogCache`] already contains the value + fn insert(&self, key: CacheKey, new: &CacheValue, old: Option<&CacheValue>); + + /// A key removed from the [`CatalogCache`] + fn evict(&self, key: CacheKey, value: &CacheValue); +} + +/// A concurrent Not-Recently-Used cache mapping [`CacheKey`] to [`CacheValue`] +#[derive(Debug, Default)] +pub struct CatalogCache 
{ + map: DashMap, + observer: Option>, + limit: Option, +} + +impl CatalogCache { + /// Create a new `CatalogCache` with an optional memory limit + pub fn new(limit: Option) -> Self { + Self { + limit: limit.map(MemoryLimiter::new), + ..Default::default() + } + } + + /// Sets a [`CatalogCacheObserver`] for this [`CatalogCache`] + pub fn with_observer(self, observer: Arc) -> Self { + Self { + observer: Some(observer), + ..self + } + } + + /// Returns the value for `key` if it exists + pub fn get(&self, key: CacheKey) -> Option { + let entry = self.map.get(&key)?; + entry.used.store(true, Ordering::Relaxed); + Some(entry.value.clone()) + } + + /// Insert the given `value` into the cache + /// + /// Skips insertion and returns false iff an entry already exists with the + /// same or greater generation + pub fn insert(&self, key: CacheKey, value: CacheValue) -> Result { + match self.map.entry(key) { + Entry::Occupied(mut o) => { + let old = &o.get().value; + if value.generation <= old.generation { + return Ok(false); + } + if let Some(l) = &self.limit { + let new_len = value.data.len(); + let cur_len = old.data.len(); + match new_len > cur_len { + true => l.reserve(new_len - cur_len)?, + false => l.free(cur_len - new_len), + } + } + if let Some(v) = &self.observer { + v.insert(key, &value, Some(old)); + } + o.insert(value.into()); + } + Entry::Vacant(v) => { + if let Some(l) = &self.limit { + l.reserve(value.data.len())?; + } + if let Some(v) = &self.observer { + v.insert(key, &value, None); + } + v.insert(value.into()); + } + } + Ok(true) + } + + /// Removes the [`CacheValue`] for the given `key` if any + pub fn delete(&self, key: CacheKey) -> Option { + match self.map.entry(key) { + Entry::Occupied(o) => { + let old = &o.get().value; + if let Some(v) = &self.observer { + v.evict(key, old) + } + if let Some(l) = &self.limit { + l.free(old.data.len()) + } + Some(o.remove().value) + } + _ => None, + } + } + + /// Returns an iterator over the items in this cache + pub fn 
list(&self) -> CacheIterator<'_> { + CacheIterator(self.map.iter()) + } + + /// Evict all entries not accessed with [`CatalogCache::get`] or updated since + /// the last call to this function + /// + /// Periodically calling this provides a Not-Recently-Used eviction policy + pub fn evict_unused(&self) { + self.map.retain(|key, entry| { + let retain = entry.used.swap(false, Ordering::Relaxed); + if !retain { + if let Some(v) = &self.observer { + v.evict(*key, &entry.value); + } + if let Some(l) = &self.limit { + l.free(entry.value.data.len()); + } + } + retain + }); + } +} + +/// Iterator for [`CatalogCache`] +#[allow(missing_debug_implementations)] +pub struct CacheIterator<'a>(dashmap::iter::Iter<'a, CacheKey, CacheEntry>); + +impl<'a> Iterator for CacheIterator<'a> { + type Item = (CacheKey, CacheValue); + + fn next(&mut self) -> Option { + let value = self.0.next()?; + Some((*value.key(), value.value().value.clone())) + } + + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + use dashmap::DashSet; + + #[derive(Debug, Default)] + struct KeyObserver { + keys: DashSet, + } + + impl KeyObserver { + fn keys(&self) -> Vec { + let mut keys: Vec<_> = self.keys.iter().map(|k| *k).collect(); + keys.sort_unstable(); + keys + } + } + + impl CatalogCacheObserver for KeyObserver { + fn insert(&self, key: CacheKey, _new: &CacheValue, _old: Option<&CacheValue>) { + self.keys.insert(key); + } + + fn evict(&self, key: CacheKey, _value: &CacheValue) { + self.keys.remove(&key); + } + } + + #[test] + fn test_basic() { + let observer = Arc::new(KeyObserver::default()); + let cache = CatalogCache::default().with_observer(Arc::clone(&observer) as _); + + let v1 = CacheValue::new("1".into(), 5); + assert!(cache.insert(CacheKey::Table(0), v1.clone()).unwrap()); + assert_eq!(cache.get(CacheKey::Table(0)).unwrap(), v1); + + // Older generation rejected + assert!(!cache + .insert(CacheKey::Table(0), 
CacheValue::new("2".into(), 3)) + .unwrap()); + + // Value unchanged + assert_eq!(cache.get(CacheKey::Table(0)).unwrap(), v1); + + // Different key accepted + let v2 = CacheValue::new("2".into(), 5); + assert!(cache.insert(CacheKey::Table(1), v2.clone()).unwrap()); + assert_eq!(cache.get(CacheKey::Table(1)).unwrap(), v2); + + let v3 = CacheValue::new("3".into(), 0); + assert!(cache.insert(CacheKey::Partition(0), v3.clone()).unwrap()); + + // Newer generation updates + let v4 = CacheValue::new("4".into(), 6); + assert!(cache.insert(CacheKey::Table(0), v4.clone()).unwrap()); + + let mut values: Vec<_> = cache.list().collect(); + values.sort_unstable_by(|(a, _), (b, _)| a.cmp(b)); + + assert_eq!( + values, + vec![ + (CacheKey::Table(0), v4.clone()), + (CacheKey::Table(1), v2), + (CacheKey::Partition(0), v3), + ] + ); + assert_eq!( + observer.keys(), + vec![ + CacheKey::Table(0), + CacheKey::Table(1), + CacheKey::Partition(0) + ] + ); + + assert_eq!(cache.get(CacheKey::Namespace(0)), None); + assert_eq!(cache.delete(CacheKey::Namespace(0)), None); + + assert_eq!(cache.get(CacheKey::Table(0)).unwrap(), v4); + assert_eq!(cache.delete(CacheKey::Table(0)).unwrap(), v4); + assert_eq!(cache.get(CacheKey::Table(0)), None); + + assert_eq!(cache.list().count(), 2); + assert_eq!(observer.keys.len(), 2); + } + + #[test] + fn test_nru() { + let observer = Arc::new(KeyObserver::default()); + let cache = CatalogCache::default().with_observer(Arc::clone(&observer) as _); + + let value = CacheValue::new("1".into(), 0); + cache.insert(CacheKey::Namespace(0), value.clone()).unwrap(); + cache.insert(CacheKey::Partition(0), value.clone()).unwrap(); + cache.insert(CacheKey::Table(0), value.clone()).unwrap(); + + cache.evict_unused(); + // Inserted records should only be evicted on the next pass + assert_eq!(cache.list().count(), 3); + assert_eq!(observer.keys.len(), 3); + + // Updating a record marks it used + cache + .insert(CacheKey::Table(0), CacheValue::new("2".into(), 1)) + .unwrap(); 
+ + // Fetching a record marks it used + cache.get(CacheKey::Partition(0)).unwrap(); + + // Insert a new record is used + cache.insert(CacheKey::Partition(1), value.clone()).unwrap(); + + cache.evict_unused(); + + // Namespace(0) evicted + let mut values: Vec<_> = cache.list().map(|(k, _)| k).collect(); + values.sort_unstable(); + let expected = vec![ + CacheKey::Table(0), + CacheKey::Partition(0), + CacheKey::Partition(1), + ]; + assert_eq!(values, expected); + assert_eq!(observer.keys(), expected); + + // Stale updates don't count as usage + assert!(!cache.insert(CacheKey::Partition(0), value).unwrap()); + + // Listing does not preserve recently used + assert_eq!(cache.list().count(), 3); + + cache.evict_unused(); + assert_eq!(cache.list().count(), 0); + assert_eq!(observer.keys.len(), 0) + } + + #[test] + fn test_limit() { + let cache = CatalogCache::new(Some(200)); + + let k1 = CacheKey::Table(1); + let k2 = CacheKey::Table(2); + let k3 = CacheKey::Table(3); + + let v_100 = Bytes::from(vec![0; 100]); + let v_20 = Bytes::from(vec![0; 20]); + + cache.insert(k1, CacheValue::new(v_100.clone(), 0)).unwrap(); + cache.insert(k2, CacheValue::new(v_100.clone(), 0)).unwrap(); + + let r = cache.insert(k3, CacheValue::new(v_20.clone(), 0)); + assert_eq!(r.unwrap_err().to_string(), "Cannot reserve additional 20 bytes for cache containing 200 bytes as would exceed limit of 200 bytes"); + + // Upsert k1 to 20 bytes + cache.insert(k1, CacheValue::new(v_20.clone(), 1)).unwrap(); + + // Can now insert k3 + cache.insert(k3, CacheValue::new(v_20.clone(), 0)).unwrap(); + + // Should evict nothing + cache.evict_unused(); + + // Cannot increase size of k3 to 100 + let r = cache.insert(k3, CacheValue::new(v_100.clone(), 1)); + assert_eq!(r.unwrap_err().to_string(), "Cannot reserve additional 80 bytes for cache containing 140 bytes as would exceed limit of 200 bytes"); + + cache.delete(k2).unwrap(); + cache.insert(k3, CacheValue::new(v_100.clone(), 1)).unwrap(); + + let r = 
cache.insert(k2, CacheValue::new(v_100.clone(), 1)); + assert_eq!(r.unwrap_err().to_string(), "Cannot reserve additional 100 bytes for cache containing 120 bytes as would exceed limit of 200 bytes"); + + // Should evict everything apart from k3 + cache.evict_unused(); + + cache.insert(k2, CacheValue::new(v_100.clone(), 1)).unwrap(); + } +} diff --git a/clap_blocks/Cargo.toml b/clap_blocks/Cargo.toml new file mode 100644 index 0000000..de5d836 --- /dev/null +++ b/clap_blocks/Cargo.toml @@ -0,0 +1,40 @@ +[package] +name = "clap_blocks" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[lints] +workspace = true + +[dependencies] +clap = { version = "4", features = ["derive", "env"] } +ed25519-dalek = { version = "2", features = ["pem"] } +futures = "0.3" +http = "0.2.11" +humantime = "2.1.0" +iox_catalog = { path = "../iox_catalog" } +iox_time = { path = "../iox_time" } +itertools = "0.12.0" +metric = { path = "../metric" } +non-empty-string = "0.2.4" +object_store = { workspace = true } +observability_deps = { path = "../observability_deps" } +parquet_cache = { path = "../parquet_cache" } +snafu = "0.8" +sysinfo = "0.30.5" +trace_exporters = { path = "../trace_exporters" } +trogging = { path = "../trogging", default-features = false, features = ["clap"] } +url = "2.4" +uuid = { version = "1", features = ["v4"] } +workspace-hack = { version = "0.1", path = "../workspace-hack" } + +[dev-dependencies] +tempfile = "3.9.0" +test_helpers = { path = "../test_helpers" } + +[features] +azure = ["object_store/azure"] # Optional Azure Object store support +gcp = ["object_store/gcp"] # Optional GCP object store support +aws = ["object_store/aws"] # Optional AWS / S3 object store support diff --git a/clap_blocks/src/bulk_ingest.rs b/clap_blocks/src/bulk_ingest.rs new file mode 100644 index 0000000..df383b5 --- /dev/null +++ b/clap_blocks/src/bulk_ingest.rs @@ -0,0 +1,274 @@ +//! 
CLI config for the router to enable bulk ingest APIs + +use ed25519_dalek::{ + pkcs8::{DecodePrivateKey, DecodePublicKey}, + SigningKey, VerifyingKey, +}; +use snafu::{ResultExt, Snafu}; +use std::{fs, io, path::PathBuf}; + +/// CLI config for bulk ingest. +#[derive(Debug, Clone, Default, clap::Parser)] +pub struct BulkIngestConfig { + /// Private signing key used for Parquet metadata returned from the `NewParquetMetadata` gRPC + /// API to prevent tampering/corruption of Parquet metadata provided by IOx to the process + /// preparing Parquet files for bulk ingest. + /// + /// This is a path to an Ed25519 private key file generated by OpenSSL with the command: + /// `openssl genpkey -algorithm ed25519 -out private-key-filename.pem` + /// + /// The public key used to verify signatures will be derived from this private key. Additional + /// public verification keys can be specified with + /// `-bulk-ingest-additional-verification-key-files` to support key rotation. + /// + /// If not specified, the `NewParquetMetadata` gRPC API will return unimplemented. + #[clap( + long = "bulk-ingest-metadata-signing-key-file", + env = "INFLUXDB_IOX_BULK_INGEST_METADATA_SIGNING_KEY_FILE" + )] + metadata_signing_key_file: Option, + + /// When in the process of rotating keys, specify paths to files containing public verification + /// keys of previously used private signing keys used for signing metadata in the past. + /// + /// These files can be derived from private key files with this OpenSSL command: + /// `openssl pkey -in private-key-filename.pem -pubout -out public-key-filename.pem` + /// + /// Example: "public-key-1.pem,public-key-2.pem" + /// + /// If verification of the metadata signature fails with the current public key derived from + /// the current signing key, these verification keys will be tested in order to allow older + /// signatures generated with the old key to still be validated. 
For best performance of + /// signature verification, specify the additional verification keys in order of most likely + /// candidates first (probably most recently used first). + /// + /// If no additional verification keys are specified, only the verification key associated with + /// the current metadata signing key will be used to validate signatures. + #[clap( + long = "bulk-ingest-additional-verification-key-files", + env = "INFLUXDB_IOX_BULK_INGEST_ADDITIONAL_VERIFICATION_KEY_FILES", + required = false, + num_args=1.., + value_delimiter = ',', + )] + additional_verification_key_files: Vec, + + /// Rather than using whatever object store configuration may have been specified as a source + /// of presigned upload URLs for bulk ingest, use a mock implementation that returns an upload + /// URL value that can be inspected but not used. + /// + /// Only useful for testing bulk ingest without setting up S3! Do not use this in production! + #[clap( + hide = true, + long = "bulk-ingest-use-mock-presigned-url-signer", + env = "INFLUXDB_IOX_BULK_INGEST_USE_MOCK_PRESIGNED_URL_SIGNER", + default_value = "false" + )] + pub use_mock_presigned_url_signer: bool, +} + +impl BulkIngestConfig { + /// Constructor for bulk ingest configuration. 
+ pub fn new( + metadata_signing_key_file: Option, + additional_verification_key_files: Vec, + use_mock_presigned_url_signer: bool, + ) -> Self { + Self { + metadata_signing_key_file, + additional_verification_key_files, + use_mock_presigned_url_signer, + } + } +} + +impl TryFrom<&BulkIngestConfig> for Option { + type Error = BulkIngestConfigError; + + fn try_from(config: &BulkIngestConfig) -> Result { + config + .metadata_signing_key_file + .as_ref() + .map(|signing_key_file| { + let signing_key: SigningKey = fs::read_to_string(signing_key_file) + .context(ReadingSigningKeyFileSnafu { + filename: &signing_key_file, + }) + .and_then(|file_contents| { + DecodePrivateKey::from_pkcs8_pem(&file_contents).context( + DecodingSigningKeySnafu { + filename: signing_key_file, + }, + ) + })?; + + let additional_verifying_keys: Vec<_> = config + .additional_verification_key_files + .iter() + .map(|verification_key_file| { + fs::read_to_string(verification_key_file) + .context(ReadingVerifyingKeyFileSnafu { + filename: &verification_key_file, + }) + .and_then(|file_contents| { + DecodePublicKey::from_public_key_pem(&file_contents).context( + DecodingVerifyingKeySnafu { + filename: verification_key_file, + }, + ) + }) + }) + .collect::, _>>()?; + + Ok(BulkIngestKeys { + signing_key, + additional_verifying_keys, + }) + }) + .transpose() + } +} + +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +pub enum BulkIngestConfigError { + #[snafu(display("Could not read signing key from {}: {source}", filename.display()))] + ReadingSigningKeyFile { + filename: PathBuf, + source: io::Error, + }, + + #[snafu(display("Could not decode signing key from {}: {source}", filename.display()))] + DecodingSigningKey { + filename: PathBuf, + source: ed25519_dalek::pkcs8::Error, + }, + + #[snafu(display("Could not read verifying key from {}: {source}", filename.display()))] + ReadingVerifyingKeyFile { + filename: PathBuf, + source: io::Error, + }, + + #[snafu(display("Could not decode verifying key 
from {}: {source}", filename.display()))] + DecodingVerifyingKey { + filename: PathBuf, + source: ed25519_dalek::pkcs8::spki::Error, + }, +} + +/// Key values extracted from the files specified to the CLI. To get an instance, first create a +/// `BulkIngestConfig`, then call `try_from` to get a `Result` containing an +/// `Option` where the `Option` will be `Some` if the `BulkIngestConfig`'s +/// `metadata_signing_key_file` value is `Some`. +/// +/// If any filenames specified anywhere in the `BulkIngestConfig` can't be read or don't contain +/// valid key values, the `try_from` implementation will return an error. +#[derive(Debug)] +pub struct BulkIngestKeys { + /// The parsed private signing key value contained in the file specified to + /// `--bulk-ingest-metadata-signing-key-file`. + pub signing_key: SigningKey, + + /// If any files were specified in `--bulk-ingest-additional-verification-key-files`, this list + /// will contain their parsed public verification key values. + pub additional_verifying_keys: Vec, +} + +#[cfg(test)] +mod tests { + use super::*; + use clap::Parser; + use std::process::Command; + use test_helpers::{assert_contains, make_temp_file, tmp_dir}; + + #[test] + fn missing_signing_key_param() { + // No signing key file -> no keys + let config = BulkIngestConfig::try_parse_from(["something"]).unwrap(); + let keys: Option = (&config).try_into().unwrap(); + assert!(keys.is_none(), "expected None, got: {:?}", keys); + + // Even if there are additional verification key files; no signing key file means no keys + let config = BulkIngestConfig::try_parse_from([ + "something", + "--bulk-ingest-additional-verification-key-files", + "some-public-key-filename.pem", + ]) + .unwrap(); + let keys: Option = (&config).try_into().unwrap(); + assert!(keys.is_none(), "expected None, got: {:?}", keys); + } + + #[test] + fn signing_key_file_not_found() { + let nonexistent_filename = "do-not-create-a-file-with-this-name-or-this-test-will-fail"; + let config = 
BulkIngestConfig::try_parse_from([ + "something", + "--bulk-ingest-metadata-signing-key-file", + nonexistent_filename, + ]) + .unwrap(); + + let keys: Result, _> = (&config).try_into(); + let err = keys.unwrap_err(); + assert_contains!( + err.to_string(), + format!("Could not read signing key from {nonexistent_filename}") + ); + } + + #[test] + fn signing_key_file_contents_invalid() { + let signing_key_file = make_temp_file("not a valid signing key"); + let signing_key_filename = signing_key_file.path().display().to_string(); + + let config = BulkIngestConfig::try_parse_from([ + "something", + "--bulk-ingest-metadata-signing-key-file", + &signing_key_filename, + ]) + .unwrap(); + + let keys: Result, _> = (&config).try_into(); + let err = keys.unwrap_err(); + assert_contains!( + err.to_string(), + format!("Could not decode signing key from {signing_key_filename}") + ); + } + + #[test] + fn valid_signing_key_file_no_additional_key_files() { + let tmp_dir = tmp_dir().unwrap(); + let signing_key_filename = tmp_dir + .path() + .join("test-private-key.pem") + .display() + .to_string(); + Command::new("openssl") + .arg("genpkey") + .arg("-algorithm") + .arg("ed25519") + .arg("-out") + .arg(&signing_key_filename) + .output() + .unwrap(); + + let config = BulkIngestConfig::try_parse_from([ + "something", + "--bulk-ingest-metadata-signing-key-file", + &signing_key_filename, + ]) + .unwrap(); + + let keys: Result, _> = (&config).try_into(); + let keys = keys.unwrap().unwrap(); + let additional_keys = keys.additional_verifying_keys; + assert!( + additional_keys.is_empty(), + "expected additional keys to be empty, got {:?}", + additional_keys + ); + } +} diff --git a/clap_blocks/src/catalog_cache.rs b/clap_blocks/src/catalog_cache.rs new file mode 100644 index 0000000..a9b8543 --- /dev/null +++ b/clap_blocks/src/catalog_cache.rs @@ -0,0 +1,154 @@ +//! Config for the catalog cache server mode. 
+ +use std::time::Duration; + +use itertools::Itertools; +use snafu::{OptionExt, Snafu}; +use url::{Host, Url}; + +use crate::memory_size::MemorySize; + +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +pub enum Error { + #[snafu(display("host '{host}' is not a prefix of '{prefix}'"))] + NotAPrefix { host: String, prefix: String }, + + #[snafu(display("host '{host}' is not a valid host"))] + NotAValidHost { host: String }, + + #[snafu(display("invalid url: {source}"))] + InvalidUrl { source: url::ParseError }, + + #[snafu(display("Expected exactly two peers"))] + InvalidPeers, +} + +/// CLI config for catalog configuration +#[derive(Debug, Clone, PartialEq, Eq, clap::Parser)] +pub struct CatalogConfig { + /// Host Name + /// + /// If provided, any matching entries in peers will be ignored + #[clap(long = "hostname", env = "INFLUXDB_IOX_HOSTNAME", value_parser = Host::parse)] + pub hostname: Option>, + + /// Peers + /// + /// Can be provided as a comma-separated list, or on the command line multiple times + #[clap( + long = "catalog-cache-peers", + env = "INFLUXDB_IOX_CATALOG_CACHE_PEERS", + required = false, + value_delimiter = ',' + )] + pub peers: Vec, + + /// Warmup delay. + /// + /// The warm-up (via dumping the cache of our peers) is delayed by the given time to make sure that we already + /// receive quorum writes. This ensure a gaplass transition / roll-out w/o any cache MISSes (esp. w/o any backend requests). + #[clap( + long = "catalog-cache-warmup-delay", + env = "INFLUXDB_IOX_CATALOG_CACHE_WARMUP_DELAY", + default_value = default_warmup_delay(), + value_parser = humantime::parse_duration, + )] + pub warmup_delay: Duration, + + /// Garbage collection interval. + /// + /// Every time this interval past, cache elements that have not been used (i.e. read or updated) since the last time + /// are evicted from the cache. 
+ #[clap( + long = "catalog-cache-gc-interval", + env = "INFLUXDB_IOX_CATALOG_CACHE_GC_INTERVAL", + default_value = default_gc_interval(), + value_parser = humantime::parse_duration, + )] + pub gc_interval: Duration, + + /// Maximum number of bytes that should be cached within the catalog cache. + /// + /// If that limit is exceeded, no new values are accepted. This is meant as a safety measurement. You should adjust + /// your pod size and the GC interval (`--catalog-cache-gc-interval` / `INFLUXDB_IOX_CATALOG_CACHE_GC_INTERVAL`) to + /// your workload. + /// + /// Can be given as absolute value or in percentage of the total available memory (e.g. `10%`). + #[clap( + long = "catalog-cache-size-limit", + env = "INFLUXDB_IOX_CATALOG_CACHE_SIZE_LIMIT", + default_value = "1073741824", // 1GB + action + )] + pub cache_size_limit: MemorySize, + + /// Number of concurrent quorum operations that a single request can trigger. + #[clap( + long = "catalog-cache-quorum-fanout", + env = "INFLUXDB_IOX_CATALOG_CACHE_QUORUM_FANOUT", + default_value_t = 10 + )] + pub quorum_fanout: usize, +} + +impl CatalogConfig { + /// Return URL of other catalog cache nodes. 
+ pub fn peers(&self) -> Result<[Url; 2], Error> { + let (peer1, peer2) = self + .peers + .iter() + .filter(|x| match (x.host(), &self.hostname) { + (Some(a), Some(r)) => &a != r, + _ => true, + }) + .collect_tuple() + .context(InvalidPeersSnafu)?; + + Ok([peer1.clone(), peer2.clone()]) + } +} + +fn default_warmup_delay() -> &'static str { + let s = humantime::format_duration(Duration::from_secs(60 * 5)).to_string(); + Box::leak(Box::new(s)) +} + +fn default_gc_interval() -> &'static str { + let s = humantime::format_duration(Duration::from_secs(60 * 15)).to_string(); + Box::leak(Box::new(s)) +} + +#[cfg(test)] +mod tests { + use super::*; + use clap::Parser; + + #[test] + fn test_peers() { + let config = CatalogConfig::parse_from([ + "binary", + "--catalog-cache-peers", + "http://peer1:8080", + "--catalog-cache-peers", + "http://peer2:9090", + ]); + let peer1 = Url::parse("http://peer1:8080").unwrap(); + let peer2 = Url::parse("http://peer2:9090").unwrap(); + + let peers = config.peers().unwrap(); + assert_eq!(peers, [peer1.clone(), peer2.clone()]); + + let mut config = CatalogConfig::parse_from([ + "binary", + "--catalog-cache-peers", + "http://peer1:8080,http://peer2:9090,http://peer3:9091", + ]); + let err = config.peers().unwrap_err(); + assert!(matches!(err, Error::InvalidPeers), "{err}"); + + config.hostname = Some(Host::parse("peer3").unwrap()); + let peers = config.peers().unwrap(); + assert_eq!(peers, [peer1.clone(), peer2.clone()]); + } +} diff --git a/clap_blocks/src/catalog_dsn.rs b/clap_blocks/src/catalog_dsn.rs new file mode 100644 index 0000000..74e84bc --- /dev/null +++ b/clap_blocks/src/catalog_dsn.rs @@ -0,0 +1,176 @@ +//! Catalog-DSN-related configs. 
+use http::uri::InvalidUri; +use iox_catalog::grpc::client::GrpcCatalogClient; +use iox_catalog::sqlite::{SqliteCatalog, SqliteConnectionOptions}; +use iox_catalog::{ + interface::Catalog, + mem::MemCatalog, + postgres::{PostgresCatalog, PostgresConnectionOptions}, +}; +use iox_time::TimeProvider; +use observability_deps::tracing::*; +use snafu::{ResultExt, Snafu}; +use std::{sync::Arc, time::Duration}; + +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +pub enum Error { + #[snafu(display("Unknown Catalog DSN {dsn}. Expected a string like 'postgresql://postgres@localhost:5432/postgres' or 'sqlite:///tmp/catalog.sqlite'"))] + UnknownCatalogDsn { dsn: String }, + + #[snafu(display("Catalog DSN not specified. Expected a string like 'postgresql://postgres@localhost:5432/postgres' or 'sqlite:///tmp/catalog.sqlite'"))] + DsnNotSpecified {}, + + #[snafu(display("Invalid URI: {source}"))] + InvalidUri { source: InvalidUri }, + + #[snafu(display("A catalog error occurred: {}", source))] + Catalog { + source: iox_catalog::interface::Error, + }, +} + +fn default_max_connections() -> &'static str { + let s = PostgresConnectionOptions::DEFAULT_MAX_CONNS.to_string(); + Box::leak(Box::new(s)) +} + +fn default_connect_timeout() -> &'static str { + let s = + humantime::format_duration(PostgresConnectionOptions::DEFAULT_CONNECT_TIMEOUT).to_string(); + Box::leak(Box::new(s)) +} + +fn default_idle_timeout() -> &'static str { + let s = humantime::format_duration(PostgresConnectionOptions::DEFAULT_IDLE_TIMEOUT).to_string(); + Box::leak(Box::new(s)) +} + +fn default_hotswap_poll_interval_timeout() -> &'static str { + let s = humantime::format_duration(PostgresConnectionOptions::DEFAULT_HOTSWAP_POLL_INTERVAL) + .to_string(); + Box::leak(Box::new(s)) +} + +/// CLI config for catalog DSN. +#[derive(Debug, Clone, Default, clap::Parser)] +pub struct CatalogDsnConfig { + /// Catalog connection string. + /// + /// The dsn determines the type of catalog used. 
+ /// + /// PostgreSQL: `postgresql://postgres@localhost:5432/postgres` + /// + /// Sqlite (a local filename /tmp/foo.sqlite): `sqlite:///tmp/foo.sqlite` - + /// note sqlite is for development/testing only and should not be used for + /// production workloads. + /// + /// Memory (ephemeral, only useful for testing): `memory` + /// + #[clap(long = "catalog-dsn", env = "INFLUXDB_IOX_CATALOG_DSN", action)] + pub dsn: Option, + + /// Maximum number of connections allowed to the catalog at any one time. + #[clap( + long = "catalog-max-connections", + env = "INFLUXDB_IOX_CATALOG_MAX_CONNECTIONS", + default_value = default_max_connections(), + action, + )] + pub max_catalog_connections: u32, + + /// Schema name for PostgreSQL-based catalogs. + #[clap( + long = "catalog-postgres-schema-name", + env = "INFLUXDB_IOX_CATALOG_POSTGRES_SCHEMA_NAME", + default_value = PostgresConnectionOptions::DEFAULT_SCHEMA_NAME, + action, + )] + pub postgres_schema_name: String, + + /// Set the amount of time to attempt connecting to the database. + #[clap( + long = "catalog-connect-timeout", + env = "INFLUXDB_IOX_CATALOG_CONNECT_TIMEOUT", + default_value = default_connect_timeout(), + value_parser = humantime::parse_duration, + )] + pub connect_timeout: Duration, + + /// Set a maximum idle duration for individual connections. + #[clap( + long = "catalog-idle-timeout", + env = "INFLUXDB_IOX_CATALOG_IDLE_TIMEOUT", + default_value = default_idle_timeout(), + value_parser = humantime::parse_duration, + )] + pub idle_timeout: Duration, + + /// If the DSN points to a file (i.e. starts with `dsn-file://`), this sets the interval how often the the file + /// should be polled for updates. + /// + /// If an update is encountered, the underlying connection pool will be hot-swapped. 
+ #[clap( + long = "catalog-hotswap-poll-interval", + env = "INFLUXDB_IOX_CATALOG_HOTSWAP_POLL_INTERVAL", + default_value = default_hotswap_poll_interval_timeout(), + value_parser = humantime::parse_duration, + )] + pub hotswap_poll_interval: Duration, +} + +impl CatalogDsnConfig { + /// Get config-dependent catalog. + pub async fn get_catalog( + &self, + app_name: &'static str, + metrics: Arc, + time_provider: Arc, + ) -> Result, Error> { + let Some(dsn) = self.dsn.as_ref() else { + return Err(Error::DsnNotSpecified {}); + }; + + if dsn.starts_with("postgres") || dsn.starts_with("dsn-file://") { + // do not log entire postgres dsn as it may contain credentials + info!(postgres_schema_name=%self.postgres_schema_name, "Catalog: Postgres"); + let options = PostgresConnectionOptions { + app_name: app_name.to_string(), + schema_name: self.postgres_schema_name.clone(), + dsn: dsn.clone(), + max_conns: self.max_catalog_connections, + connect_timeout: self.connect_timeout, + idle_timeout: self.idle_timeout, + hotswap_poll_interval: self.hotswap_poll_interval, + }; + Ok(Arc::new( + PostgresCatalog::connect(options, metrics) + .await + .context(CatalogSnafu)?, + )) + } else if dsn == "memory" { + info!("Catalog: In-memory"); + let mem = MemCatalog::new(metrics, time_provider); + Ok(Arc::new(mem)) + } else if let Some(file_path) = dsn.strip_prefix("sqlite://") { + info!(file_path, "Catalog: Sqlite"); + let options = SqliteConnectionOptions { + file_path: file_path.to_string(), + }; + Ok(Arc::new( + SqliteCatalog::connect(options, metrics) + .await + .context(CatalogSnafu)?, + )) + } else if dsn.starts_with("http://") || dsn.starts_with("https://") { + info!("Catalog: gRPC"); + let uri = dsn.parse().context(InvalidUriSnafu)?; + let grpc = GrpcCatalogClient::new(uri, metrics, time_provider); + Ok(Arc::new(grpc)) + } else { + Err(Error::UnknownCatalogDsn { + dsn: dsn.to_string(), + }) + } + } +} diff --git a/clap_blocks/src/compactor.rs b/clap_blocks/src/compactor.rs new file 
mode 100644 index 0000000..9b63bc8 --- /dev/null +++ b/clap_blocks/src/compactor.rs @@ -0,0 +1,156 @@ +//! CLI config for compactor-related commands + +use std::num::NonZeroUsize; + +use crate::{gossip::GossipConfig, memory_size::MemorySize}; + +use super::compactor_scheduler::CompactorSchedulerConfig; + +/// CLI config for compactor +#[derive(Debug, Clone, clap::Parser)] +pub struct CompactorConfig { + /// Gossip config. + #[clap(flatten)] + pub gossip_config: GossipConfig, + + /// Configuration for the compactor scheduler + #[clap(flatten)] + pub compactor_scheduler_config: CompactorSchedulerConfig, + + /// Number of partitions that should be compacted in parallel. + /// + /// This should usually be larger than the compaction job + /// concurrency since one partition can spawn multiple compaction + /// jobs. + #[clap( + long = "compaction-partition-concurrency", + env = "INFLUXDB_IOX_COMPACTION_PARTITION_CONCURRENCY", + default_value = "100", + action + )] + pub compaction_partition_concurrency: NonZeroUsize, + + /// Number of concurrent compaction jobs scheduled to DataFusion. + /// + /// This should usually be smaller than the partition concurrency + /// since one partition can spawn multiple DF compaction jobs. + #[clap( + long = "compaction-df-concurrency", + env = "INFLUXDB_IOX_COMPACTION_DF_CONCURRENCY", + default_value = "10", + action + )] + pub compaction_df_concurrency: NonZeroUsize, + + /// Number of jobs PER PARTITION that move files in and out of the + /// scratchpad. + #[clap( + long = "compaction-partition-scratchpad-concurrency", + env = "INFLUXDB_IOX_COMPACTION_PARTITION_SCRATCHPAD_CONCURRENCY", + default_value = "10", + action + )] + pub compaction_partition_scratchpad_concurrency: NonZeroUsize, + + /// Number of threads to use for the compactor query execution, + /// compaction and persistence. 
+ /// If not specified, defaults to one less than the number of cores on the system + #[clap( + long = "query-exec-thread-count", + env = "INFLUXDB_IOX_QUERY_EXEC_THREAD_COUNT", + action + )] + pub query_exec_thread_count: Option, + + /// Size of memory pool used during compaction plan execution, in + /// bytes. + /// + /// If compaction plans attempt to allocate more than this many + /// bytes during execution, they will error with + /// "ResourcesExhausted". + /// + /// Can be given as absolute value or in percentage of the total available memory (e.g. `10%`). + #[clap( + long = "exec-mem-pool-bytes", + env = "INFLUXDB_IOX_EXEC_MEM_POOL_BYTES", + default_value = "17179869184", // 16GB + action + )] + pub exec_mem_pool_bytes: MemorySize, + + /// Overrides INFLUXDB_IOX_EXEC_MEM_POOL_BYTES to set the size of memory pool + /// used during compaction DF plan execution. This value is expressed as a percent + /// of the memory limit for the cgroup (e.g. 70 = 70% of the cgroup memory limit). + /// This is converted to a byte limit as the compactor starts. + /// + /// Extreme values (<20% or >90%) are ignored and INFLUXDB_IOX_EXEC_MEM_POOL_BYTES + /// is used. It will also use INFLUXDB_IOX_EXEC_MEM_POOL_BYTES if we fail to read + /// the cgroup limit, or it doesn't parse to a sane value. + /// + /// If compaction plans attempt to allocate more than the computed byte limit + /// during execution, they will error with "ResourcesExhausted". + #[clap( + long = "exec-mem-pool-percent", + env = "INFLUXDB_IOX_EXEC_MEM_POOL_PERCENT", + default_value = "70", + action + )] + pub exec_mem_pool_percent: u64, + + /// Maximum duration of the per-partition compaction task in seconds. + #[clap( + long = "compaction-partition-timeout-secs", + env = "INFLUXDB_IOX_COMPACTION_PARTITION_TIMEOUT_SECS", + default_value = "1800", + action + )] + pub partition_timeout_secs: u64, + + /// Shadow mode. + /// + /// This will NOT write / commit any output to the object store or catalog. 
+ /// + /// This is mostly useful for debugging. + #[clap( + long = "compaction-shadow-mode", + env = "INFLUXDB_IOX_COMPACTION_SHADOW_MODE", + action + )] + pub shadow_mode: bool, + + /// Enable scratchpad. + /// + /// This allows disabling the scratchpad in production. + /// + /// Disabling this is useful for testing performance and memory consequences of the scratchpad. + #[clap( + long = "compaction-enable-scratchpad", + env = "INFLUXDB_IOX_COMPACTION_ENABLE_SCRATCHPAD", + default_value = "true", + action + )] + pub enable_scratchpad: bool, + + /// Only process all discovered partitions once. + /// + /// By default the compactor will continuously loop over all + /// partitions looking for work. Setting this option results in + /// exiting the loop after the one iteration. + #[clap( + long = "compaction-process-once", + env = "INFLUXDB_IOX_COMPACTION_PROCESS_ONCE", + action + )] + pub process_once: bool, + + /// Limit the number of partition fetch queries to at most the specified + /// number of queries per second. + /// + /// Queries are smoothed over the full second. + #[clap( + long = "max-partition-fetch-queries-per-second", + env = "INFLUXDB_IOX_MAX_PARTITION_FETCH_QUERIES_PER_SECOND", + action + )] + pub max_partition_fetch_queries_per_second: Option, +} diff --git a/clap_blocks/src/compactor_scheduler.rs b/clap_blocks/src/compactor_scheduler.rs new file mode 100644 index 0000000..e2b3c8f --- /dev/null +++ b/clap_blocks/src/compactor_scheduler.rs @@ -0,0 +1,351 @@ +//! Compactor-Scheduler-related configs. + +use crate::socket_addr::SocketAddr; +use std::str::FromStr; + +/// Compaction Scheduler type. +#[derive(Debug, Default, Clone, Copy, PartialEq, clap::ValueEnum)] +pub enum CompactorSchedulerType { + /// Perform scheduling decisions locally. + #[default] + Local, + + /// Perform scheduling decisions remotely. + Remote, +} + +/// CLI config for compactor scheduler. 
+#[derive(Debug, Clone, Default, clap::Parser)] +pub struct ShardConfigForLocalScheduler { + /// Number of shards. + /// + /// If this is set then the shard ID MUST also be set. If both are not provided, sharding is disabled. + /// (shard ID can be provided by the host name) + #[clap( + long = "compaction-shard-count", + env = "INFLUXDB_IOX_COMPACTION_SHARD_COUNT", + action + )] + pub shard_count: Option, + + /// Shard ID. + /// + /// Starts at 0, must be smaller than the number of shard. + /// + /// If this is set then the shard count MUST also be set. If both are not provided, sharding is disabled. + #[clap( + long = "compaction-shard-id", + env = "INFLUXDB_IOX_COMPACTION_SHARD_ID", + requires("shard_count"), + action + )] + pub shard_id: Option, + + /// Host Name + /// + /// comprised of leading text (e.g. 'iox-shared-compactor-'), ending with shard_id (e.g. '0'). + /// When shard_count is specified, but shard_id is not specified, the id is extracted from hostname. + #[clap(env = "HOSTNAME")] + pub hostname: Option, +} + +/// CLI config for partitions_source used by the scheduler. +#[derive(Debug, Clone, Default, clap::Parser)] +pub struct PartitionSourceConfigForLocalScheduler { + /// The compactor will only consider compacting partitions that + /// have new Parquet files created within this many minutes. + #[clap( + long = "compaction_partition_minute_threshold", + env = "INFLUXDB_IOX_COMPACTION_PARTITION_MINUTE_THRESHOLD", + default_value = "10", + action + )] + pub compaction_partition_minute_threshold: u64, + + /// Filter partitions to the given set of IDs. + /// + /// This is mostly useful for debugging. + #[clap( + long = "compaction-partition-filter", + env = "INFLUXDB_IOX_COMPACTION_PARTITION_FILTER", + action + )] + pub partition_filter: Option>, + + /// Compact all partitions found in the catalog, no matter if/when + /// they received writes. 
+ #[clap( + long = "compaction-process-all-partitions", + env = "INFLUXDB_IOX_COMPACTION_PROCESS_ALL_PARTITIONS", + default_value = "false", + action + )] + pub process_all_partitions: bool, + + /// Ignores "partition marked w/ error and shall be skipped" entries in the catalog. + /// + /// This is mostly useful for debugging. + #[clap( + long = "compaction-ignore-partition-skip-marker", + env = "INFLUXDB_IOX_COMPACTION_IGNORE_PARTITION_SKIP_MARKER", + action + )] + pub ignore_partition_skip_marker: bool, +} + +/// CLI config for scheduler's gossip. +#[derive(Debug, Clone, clap::Parser)] +pub struct CompactorSchedulerGossipConfig { + /// A comma-delimited set of seed gossip peer addresses. + /// + /// Example: "10.0.0.1:4242,10.0.0.2:4242" + /// + /// These seeds will be used to discover all other peers that talk to the + /// same seeds. Typically all nodes in the cluster should use the same set + /// of seeds. + #[clap( + long = "compactor-scheduler-gossip-seed-list", + env = "INFLUXDB_IOX_COMPACTOR_SCHEDULER_GOSSIP_SEED_LIST", + required = false, + num_args=1.., + value_delimiter = ',', + requires = "scheduler_gossip_bind_address", // Field name, not flag + )] + pub scheduler_seed_list: Vec, + + /// The UDP socket address IOx will use for gossip communication between + /// peers. + /// + /// Example: "0.0.0.0:4242" + /// + /// If not provided, the gossip sub-system is disabled. 
+ #[clap( + long = "compactor-scheduler-gossip-bind-address", + env = "INFLUXDB_IOX_COMPACTOR_SCHEDULER_GOSSIP_BIND_ADDR", + default_value = "0.0.0.0:0", + required = false, + action + )] + pub scheduler_gossip_bind_address: SocketAddr, +} + +impl Default for CompactorSchedulerGossipConfig { + fn default() -> Self { + Self { + scheduler_seed_list: vec![], + scheduler_gossip_bind_address: SocketAddr::from_str("0.0.0.0:4324").unwrap(), + } + } +} + +impl CompactorSchedulerGossipConfig { + /// constructor for GossipConfig + /// + pub fn new(bind_address: &str, seed_list: Vec) -> Self { + Self { + scheduler_seed_list: seed_list, + scheduler_gossip_bind_address: SocketAddr::from_str(bind_address).unwrap(), + } + } +} + +/// CLI config for compactor scheduler. +#[derive(Debug, Clone, Default, clap::Parser)] +pub struct CompactorSchedulerConfig { + /// Scheduler type to use. + #[clap( + value_enum, + long = "compactor-scheduler", + env = "INFLUXDB_IOX_COMPACTION_SCHEDULER", + default_value = "local", + action + )] + pub compactor_scheduler_type: CompactorSchedulerType, + + /// Maximum number of files that the compactor will try and + /// compact in a single plan. + /// + /// The higher this setting is the fewer compactor plans are run + /// and thus fewer resources over time are consumed by the + /// compactor. Increasing this setting also increases the peak + /// memory used for each compaction plan, and thus if it is set + /// too high, the compactor plans may exceed available memory. + #[clap( + long = "compaction-max-num-files-per-plan", + env = "INFLUXDB_IOX_COMPACTION_MAX_NUM_FILES_PER_PLAN", + default_value = "20", + action + )] + pub max_num_files_per_plan: usize, + + /// Desired max size of compacted parquet files. + /// + /// Note this is a target desired value, rather than a guarantee. 
+ /// 1024 * 1024 * 100 = 104,857,600 + #[clap( + long = "compaction-max-desired-size-bytes", + env = "INFLUXDB_IOX_COMPACTION_MAX_DESIRED_FILE_SIZE_BYTES", + default_value = "104857600", + action + )] + pub max_desired_file_size_bytes: u64, + + /// Minimum number of L1 files to compact to L2. + /// + /// If there are more than this many L1 (by definition non + /// overlapping) files in a partition, the compactor will compact + /// them together into one or more larger L2 files. + /// + /// Setting this value higher in general results in fewer overall + /// resources spent on compaction but more files per partition (and + /// thus less optimal compression and query performance). + #[clap( + long = "compaction-min-num-l1-files-to-compact", + env = "INFLUXDB_IOX_COMPACTION_MIN_NUM_L1_FILES_TO_COMPACT", + default_value = "10", + action + )] + pub min_num_l1_files_to_compact: usize, + + /// Maximum number of columns in a table of a partition that + /// will be able to considered to get compacted + /// + /// If a table has more than this many columns, the compactor will + /// not compact it, to avoid large memory use. + #[clap( + long = "compaction-max-num-columns-per-table", + env = "INFLUXDB_IOX_COMPACTION_MAX_NUM_COLUMNS_PER_TABLE", + default_value = "10000", + action + )] + pub max_num_columns_per_table: usize, + + /// Percentage of desired max file size for "leading edge split" + /// optimization. + /// + /// This setting controls the estimated output file size at which + /// the compactor will apply the "leading edge" optimization. 
+ /// + /// When compacting files together, if the output size is + /// estimated to be greater than the following quantity, the + /// "leading edge split" optimization will be applied: + /// + /// percentage_max_file_size * target_file_size + /// + /// This value must be between (0, 100) + /// + /// Default is 20 + #[clap( + long = "compaction-percentage-max-file_size", + env = "INFLUXDB_IOX_COMPACTION_PERCENTAGE_MAX_FILE_SIZE", + default_value = "20", + action + )] + pub percentage_max_file_size: u16, + + /// Enable new priority-based compaction selection. + /// + /// Eventually, this will be the only way to select partitions. + /// + /// Default is false + #[clap( + long = "compaction-priority-based-selection", + env = "INFLUXDB_IOX_COMPACTION_PRIORITY_BASED_SELECTION", + default_value = "false", + action + )] + pub priority_based_selection: bool, + + /// Split file percentage for "leading edge split" + /// + /// To reduce the likelihood of recompacting the same data too many + /// times, the compactor uses the "leading edge split" + /// optimization for the common case where the new data written + /// into a partition also has the most recent timestamps. + /// + /// When compacting multiple files together, if the compactor + /// estimates the resulting file will be large enough (see + /// `percentage_max_file_size`) it creates two output files + /// rather than one, split by time, like this: + /// + /// `|-------------- older_data -----------------||---- newer_data ----|` + /// + /// In the common case, the file containing `older_data` is less + /// likely to overlap with new data written in. + /// + /// This setting controls what percentage of data is placed into + /// the `older_data` portion. + /// + /// Increasing this value increases the average size of compacted + /// files after the first round of compaction. 
However, doing so + /// also increase the likelihood that late arriving data will + /// overlap with larger existing files, necessitating additional + /// compaction rounds. + /// + /// This value must be between (0, 100) + #[clap( + long = "compaction-split-percentage", + env = "INFLUXDB_IOX_COMPACTION_SPLIT_PERCENTAGE", + default_value = "80", + action + )] + pub split_percentage: u16, + + /// Partition source config used by the local scheduler. + #[clap(flatten)] + pub partition_source_config: PartitionSourceConfigForLocalScheduler, + + /// Shard config used by the local scheduler. + #[clap(flatten)] + pub shard_config: ShardConfigForLocalScheduler, + + /// Gossip config. + #[clap(flatten)] + pub gossip_config: CompactorSchedulerGossipConfig, +} + +#[cfg(test)] +mod tests { + use super::*; + use clap::Parser; + use test_helpers::assert_contains; + + #[test] + fn default_compactor_scheduler_type_is_local() { + let config = CompactorSchedulerConfig::try_parse_from(["my_binary"]).unwrap(); + assert_eq!( + config.compactor_scheduler_type, + CompactorSchedulerType::Local + ); + } + + #[test] + fn can_specify_local() { + let config = CompactorSchedulerConfig::try_parse_from([ + "my_binary", + "--compactor-scheduler", + "local", + ]) + .unwrap(); + assert_eq!( + config.compactor_scheduler_type, + CompactorSchedulerType::Local + ); + } + + #[test] + fn any_other_scheduler_type_string_is_invalid() { + let error = CompactorSchedulerConfig::try_parse_from([ + "my_binary", + "--compactor-scheduler", + "hello", + ]) + .unwrap_err() + .to_string(); + assert_contains!( + &error, + "invalid value 'hello' for '--compactor-scheduler '" + ); + assert_contains!(&error, "[possible values: local, remote]"); + } +} diff --git a/clap_blocks/src/garbage_collector.rs b/clap_blocks/src/garbage_collector.rs new file mode 100644 index 0000000..0b10d78 --- /dev/null +++ b/clap_blocks/src/garbage_collector.rs @@ -0,0 +1,150 @@ +//! 
Garbage Collector configuration +use clap::Parser; +use humantime::parse_duration; +use std::{fmt::Debug, time::Duration}; + +/// Configuration specific to the object store garbage collector +#[derive(Debug, Clone, Parser, Copy)] +pub struct GarbageCollectorConfig { + /// If this flag is specified, don't delete the files in object storage. Only print the files + /// that would be deleted if this flag wasn't specified. + #[clap(long, env = "INFLUXDB_IOX_GC_DRY_RUN")] + pub dry_run: bool, + + /// Items in the object store that are older than this duration that are not referenced in the + /// catalog will be deleted. + /// Parsed with + /// + /// If not specified, defaults to 14 days ago. + #[clap( + long, + default_value = "14d", + value_parser = parse_duration, + env = "INFLUXDB_IOX_GC_OBJECTSTORE_CUTOFF" + )] + pub objectstore_cutoff: Duration, + + /// Number of minutes to sleep between iterations of the objectstore list loop. + /// This is the sleep between entirely fresh list operations. + /// Defaults to 30 minutes. + #[clap( + long, + default_value_t = 30, + env = "INFLUXDB_IOX_GC_OBJECTSTORE_SLEEP_INTERVAL_MINUTES" + )] + pub objectstore_sleep_interval_minutes: u64, + + /// Number of milliseconds to sleep between listing consecutive chunks of objecstore files. + /// Object store listing is processed in batches; this is the sleep between batches. + /// Defaults to 1000 milliseconds. + #[clap( + long, + default_value_t = 1000, + env = "INFLUXDB_IOX_GC_OBJECTSTORE_SLEEP_INTERVAL_BATCH_MILLISECONDS" + )] + pub objectstore_sleep_interval_batch_milliseconds: u64, + + /// Parquet file rows in the catalog flagged for deletion before this duration will be deleted. + /// Parsed with + /// + /// If not specified, defaults to 14 days ago. 
+ #[clap( + long, + default_value = "14d", + value_parser = parse_duration, + env = "INFLUXDB_IOX_GC_PARQUETFILE_CUTOFF" + )] + pub parquetfile_cutoff: Duration, + + /// Number of minutes to sleep between iterations of the parquet file deletion loop. + /// + /// Defaults to 30 minutes. + /// + /// If both INFLUXDB_IOX_GC_PARQUETFILE_SLEEP_INTERVAL_MINUTES and + /// INFLUXDB_IOX_GC_PARQUETFILE_SLEEP_INTERVAL are specified, the smaller is chosen + #[clap(long, env = "INFLUXDB_IOX_GC_PARQUETFILE_SLEEP_INTERVAL_MINUTES")] + pub parquetfile_sleep_interval_minutes: Option, + + /// Duration to sleep between iterations of the parquet file deletion loop. + /// + /// Defaults to 30 minutes. + /// + /// If both INFLUXDB_IOX_GC_PARQUETFILE_SLEEP_INTERVAL_MINUTES and + /// INFLUXDB_IOX_GC_PARQUETFILE_SLEEP_INTERVAL are specified, the smaller is chosen + #[clap( + long, + value_parser = parse_duration, + env = "INFLUXDB_IOX_GC_PARQUETFILE_SLEEP_INTERVAL" + )] + pub parquetfile_sleep_interval: Option, + + /// Number of minutes to sleep between iterations of the retention code. + /// Defaults to 35 minutes to reduce incidence of it running at the same time as the parquet + /// file deleter. 
+ #[clap( + long, + default_value_t = 35, + env = "INFLUXDB_IOX_GC_RETENTION_SLEEP_INTERVAL_MINUTES" + )] + pub retention_sleep_interval_minutes: u64, +} + +impl GarbageCollectorConfig { + /// Returns the parquet_file sleep interval + pub fn parquetfile_sleep_interval(&self) -> Duration { + match ( + self.parquetfile_sleep_interval, + self.parquetfile_sleep_interval_minutes, + ) { + (None, None) => Duration::from_secs(30 * 60), + (Some(d), None) => d, + (None, Some(m)) => Duration::from_secs(m * 60), + (Some(d), Some(m)) => d.min(Duration::from_secs(m * 60)), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_gc_config() { + let a: &[&str] = &[]; + let config = GarbageCollectorConfig::parse_from(a); + assert_eq!( + config.parquetfile_sleep_interval(), + Duration::from_secs(30 * 60) + ); + + let config = + GarbageCollectorConfig::parse_from(["something", "--parquetfile-sleep-interval", "3d"]); + + assert_eq!( + config.parquetfile_sleep_interval(), + Duration::from_secs(24 * 60 * 60 * 3) + ); + + let config = GarbageCollectorConfig::parse_from([ + "something", + "--parquetfile-sleep-interval-minutes", + "34", + ]); + assert_eq!( + config.parquetfile_sleep_interval(), + Duration::from_secs(34 * 60) + ); + + let config = GarbageCollectorConfig::parse_from([ + "something", + "--parquetfile-sleep-interval-minutes", + "34", + "--parquetfile-sleep-interval", + "35m", + ]); + assert_eq!( + config.parquetfile_sleep_interval(), + Duration::from_secs(34 * 60) + ); + } +} diff --git a/clap_blocks/src/gossip.rs b/clap_blocks/src/gossip.rs new file mode 100644 index 0000000..47365ba --- /dev/null +++ b/clap_blocks/src/gossip.rs @@ -0,0 +1,52 @@ +//! CLI config for cluster gossip communication. + +use crate::socket_addr::SocketAddr; +use std::str::FromStr; + +/// Configuration parameters for the cluster gossip communication mechanism. 
+#[derive(Debug, Clone, clap::Parser)] +#[allow(missing_copy_implementations)] +pub struct GossipConfig { + /// A comma-delimited set of seed gossip peer addresses. + /// + /// Example: "10.0.0.1:4242,10.0.0.2:4242" + /// + /// These seeds will be used to discover all other peers that talk to the + /// same seeds. Typically all nodes in the cluster should use the same set + /// of seeds. + #[clap( + long = "gossip-seed-list", + env = "INFLUXDB_IOX_GOSSIP_SEED_LIST", + required = false, + num_args=1.., + value_delimiter = ',', + requires = "gossip_bind_address", // Field name, not flag + )] + pub seed_list: Vec, + + /// The UDP socket address IOx will use for gossip communication between + /// peers. + /// + /// Example: "0.0.0.0:4242" + /// + /// If not provided, the gossip sub-system is disabled. + #[clap( + long = "gossip-bind-address", + env = "INFLUXDB_IOX_GOSSIP_BIND_ADDR", + default_value = "0.0.0.0:4242", + required = false, + action + )] + pub gossip_bind_address: SocketAddr, +} + +impl GossipConfig { + /// constructor for GossipConfig + /// + pub fn new(bind_address: &str, seed_list: Vec) -> Self { + Self { + seed_list, + gossip_bind_address: SocketAddr::from_str(bind_address).unwrap(), + } + } +} diff --git a/clap_blocks/src/ingester.rs b/clap_blocks/src/ingester.rs new file mode 100644 index 0000000..be2ab26 --- /dev/null +++ b/clap_blocks/src/ingester.rs @@ -0,0 +1,101 @@ +//! CLI config for the ingester using the RPC write path + +use std::{num::NonZeroUsize, path::PathBuf}; + +use crate::gossip::GossipConfig; + +/// CLI config for the ingester using the RPC write path +#[derive(Debug, Clone, clap::Parser)] +#[allow(missing_copy_implementations)] +pub struct IngesterConfig { + /// Gossip config. + #[clap(flatten)] + pub gossip_config: GossipConfig, + + /// Where this ingester instance should store its write-ahead log files. Each ingester instance + /// must have its own directory. 
+ #[clap(long = "wal-directory", env = "INFLUXDB_IOX_WAL_DIRECTORY", action)] + pub wal_directory: PathBuf, + + /// Specify the maximum allowed incoming RPC write message size sent by the + /// Router. + #[clap( + long = "rpc-write-max-incoming-bytes", + env = "INFLUXDB_IOX_RPC_WRITE_MAX_INCOMING_BYTES", + default_value = "104857600", // 100MiB + )] + pub rpc_write_max_incoming_bytes: usize, + + /// The number of seconds between WAL file rotations. + #[clap( + long = "wal-rotation-period-seconds", + env = "INFLUXDB_IOX_WAL_ROTATION_PERIOD_SECONDS", + default_value = "300", + action + )] + pub wal_rotation_period_seconds: u64, + + /// Sets how many queries the ingester will handle simultaneously before + /// rejecting further incoming requests. + #[clap( + long = "concurrent-query-limit", + env = "INFLUXDB_IOX_CONCURRENT_QUERY_LIMIT", + default_value = "20", + action + )] + pub concurrent_query_limit: usize, + + /// The maximum number of persist tasks that can run simultaneously. + #[clap( + long = "persist-max-parallelism", + env = "INFLUXDB_IOX_PERSIST_MAX_PARALLELISM", + default_value = "5", + action + )] + pub persist_max_parallelism: usize, + + /// The maximum number of persist tasks that can be queued at any one time. + /// + /// Once this limit is reached, ingest is blocked until the persist backlog + /// is reduced. + #[clap( + long = "persist-queue-depth", + env = "INFLUXDB_IOX_PERSIST_QUEUE_DEPTH", + default_value = "250", + action + )] + pub persist_queue_depth: usize, + + /// The limit at which a partition's estimated persistence cost causes it to + /// be queued for persistence. + #[clap( + long = "persist-hot-partition-cost", + env = "INFLUXDB_IOX_PERSIST_HOT_PARTITION_COST", + default_value = "20000000", // 20,000,000 + action + )] + pub persist_hot_partition_cost: usize, + + /// An optional lower bound byte size limit that buffered data within a + /// partition must reach in order to be converted into an incremental + /// snapshot at query time. 
+ /// + /// Snapshots improve query performance by amortising response generation at + /// the expense of a small memory overhead. Snapshots are retained until the + /// buffer is persisted. + #[clap( + long = "min-partition-snapshot-size", + env = "INFLUXDB_IOX_MIN_PARTITION_SNAPSHOT_SIZE" + )] + pub min_partition_snapshot_size: Option, + + /// Limit the number of partitions that may be buffered in a single + /// namespace (across all tables) at any one time. + /// + /// This limit is disabled by default. + #[clap( + long = "max-partitions-per-namespace", + env = "INFLUXDB_IOX_MAX_PARTITIONS_PER_NAMESPACE" + )] + pub max_partitions_per_namespace: Option, +} diff --git a/clap_blocks/src/ingester_address.rs b/clap_blocks/src/ingester_address.rs new file mode 100644 index 0000000..90a8e8d --- /dev/null +++ b/clap_blocks/src/ingester_address.rs @@ -0,0 +1,308 @@ +//! Shared configuration and tests for accepting ingester addresses as arguments. + +use http::uri::{InvalidUri, InvalidUriParts, Uri}; +use snafu::{ResultExt, Snafu}; +use std::{fmt::Display, str::FromStr}; + +/// An address to an ingester's gRPC API. Create by using `IngesterAddress::from_str`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct IngesterAddress { + uri: Uri, +} + +/// Why a specified ingester address might be invalid +#[allow(missing_docs)] +#[derive(Snafu, Debug)] +pub enum Error { + #[snafu(display("{source}"))] + Invalid { source: InvalidUri }, + + #[snafu(display("Port is required; no port found in `{value}`"))] + MissingPort { value: String }, + + #[snafu(context(false))] + InvalidParts { source: InvalidUriParts }, +} + +impl FromStr for IngesterAddress { + type Err = Error; + + fn from_str(s: &str) -> Result { + let uri = Uri::from_str(s).context(InvalidSnafu)?; + + if uri.port().is_none() { + return MissingPortSnafu { value: s }.fail(); + } + + let uri = if uri.scheme().is_none() { + Uri::from_str(&format!("http://{s}")).context(InvalidSnafu)? 
+ } else { + uri + }; + + Ok(Self { uri }) + } +} + +impl Display for IngesterAddress { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.uri) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use clap::{error::ErrorKind, Parser}; + use std::env; + use test_helpers::{assert_contains, assert_error}; + + /// Applications such as the router MUST have valid ingester addresses. + #[derive(Debug, Clone, clap::Parser)] + struct RouterConfig { + #[clap( + long = "ingester-addresses", + env = "TEST_INFLUXDB_IOX_INGESTER_ADDRESSES", + required = true, + num_args=1.., + value_delimiter = ',' + )] + pub(crate) ingester_addresses: Vec, + } + + #[test] + fn error_if_not_specified_when_required() { + assert_error!( + RouterConfig::try_parse_from(["my_binary"]), + ref e if e.kind() == ErrorKind::MissingRequiredArgument + ); + } + + /// Applications such as the querier might not have any ingester addresses, but if they have + /// any, they should be valid. 
+ #[derive(Debug, Clone, clap::Parser)] + struct QuerierConfig { + #[clap( + long = "ingester-addresses", + env = "TEST_INFLUXDB_IOX_INGESTER_ADDRESSES", + required = false, + num_args=0.., + value_delimiter = ',' + )] + pub(crate) ingester_addresses: Vec, + } + + #[test] + fn empty_if_not_specified_when_optional() { + assert!(QuerierConfig::try_parse_from(["my_binary"]) + .unwrap() + .ingester_addresses + .is_empty()); + } + + fn both_types_valid(args: &[&'static str], expected: &[&'static str]) { + let router = RouterConfig::try_parse_from(args).unwrap(); + let actual: Vec<_> = router + .ingester_addresses + .iter() + .map(ToString::to_string) + .collect(); + assert_eq!(actual, expected); + + let querier = QuerierConfig::try_parse_from(args).unwrap(); + let actual: Vec<_> = querier + .ingester_addresses + .iter() + .map(ToString::to_string) + .collect(); + assert_eq!(actual, expected); + } + + fn both_types_error(args: &[&'static str], expected_error_message: &'static str) { + assert_contains!( + RouterConfig::try_parse_from(args).unwrap_err().to_string(), + expected_error_message + ); + assert_contains!( + QuerierConfig::try_parse_from(args).unwrap_err().to_string(), + expected_error_message + ); + } + + #[test] + fn accepts_one() { + let args = [ + "my_binary", + "--ingester-addresses", + "http://example.com:1234", + ]; + let expected = ["http://example.com:1234/"]; + + both_types_valid(&args, &expected); + } + + #[test] + fn accepts_two() { + let args = [ + "my_binary", + "--ingester-addresses", + "http://example.com:1234,http://example.com:5678", + ]; + let expected = ["http://example.com:1234/", "http://example.com:5678/"]; + + both_types_valid(&args, &expected); + } + + #[test] + fn rejects_any_invalid_uri() { + let args = [ + "my_binary", + "--ingester-addresses", + "http://example.com:1234,", // note the trailing comma; empty URIs are invalid + ]; + let expected = "error: invalid value '' for '--ingester-addresses"; + + both_types_error(&args, expected); 
+ } + + #[test] + fn rejects_uri_without_port() { + let args = [ + "my_binary", + "--ingester-addresses", + "example.com,http://example.com:1234", + ]; + let expected = "Port is required; no port found in `example.com`"; + + both_types_error(&args, expected); + } + + #[test] + fn no_scheme_assumes_http() { + let args = [ + "my_binary", + "--ingester-addresses", + "http://example.com:1234,somescheme://0.0.0.0:1000,127.0.0.1:8080", + ]; + let expected = [ + "http://example.com:1234/", + "somescheme://0.0.0.0:1000/", + "http://127.0.0.1:8080/", + ]; + + both_types_valid(&args, &expected); + } + + #[test] + fn specifying_flag_multiple_times_works() { + let args = [ + "my_binary", + "--ingester-addresses", + "http://example.com:1234", + "--ingester-addresses", + "somescheme://0.0.0.0:1000", + "--ingester-addresses", + "127.0.0.1:8080", + ]; + let expected = [ + "http://example.com:1234/", + "somescheme://0.0.0.0:1000/", + "http://127.0.0.1:8080/", + ]; + + both_types_valid(&args, &expected); + } + + #[test] + fn specifying_flag_multiple_times_and_using_commas_works() { + let args = [ + "my_binary", + "--ingester-addresses", + "http://example.com:1234", + "--ingester-addresses", + "somescheme://0.0.0.0:1000,127.0.0.1:8080", + ]; + let expected = [ + "http://example.com:1234/", + "somescheme://0.0.0.0:1000/", + "http://127.0.0.1:8080/", + ]; + + both_types_valid(&args, &expected); + } + + /// Use an environment variable name not shared with any other config to avoid conflicts when + /// setting the var in tests. + /// Applications such as the router MUST have valid ingester addresses. 
+ #[derive(Debug, Clone, clap::Parser)] + struct EnvRouterConfig { + #[clap( + long = "ingester-addresses", + env = "NO_CONFLICT_ROUTER_TEST_INFLUXDB_IOX_INGESTER_ADDRESSES", + required = true, + num_args=1.., + value_delimiter = ',' + )] + pub(crate) ingester_addresses: Vec, + } + + #[test] + fn required_and_specified_via_environment_variable() { + env::set_var( + "NO_CONFLICT_ROUTER_TEST_INFLUXDB_IOX_INGESTER_ADDRESSES", + "http://example.com:1234,somescheme://0.0.0.0:1000,127.0.0.1:8080", + ); + let args = ["my_binary"]; + let expected = [ + "http://example.com:1234/", + "somescheme://0.0.0.0:1000/", + "http://127.0.0.1:8080/", + ]; + + let router = EnvRouterConfig::try_parse_from(args).unwrap(); + let actual: Vec<_> = router + .ingester_addresses + .iter() + .map(ToString::to_string) + .collect(); + assert_eq!(actual, expected); + } + + /// Use an environment variable name not shared with any other config to avoid conflicts when + /// setting the var in tests. + /// Applications such as the querier might not have any ingester addresses, but if they have + /// any, they should be valid. 
+ #[derive(Debug, Clone, clap::Parser)] + struct EnvQuerierConfig { + #[clap( + long = "ingester-addresses", + env = "NO_CONFLICT_QUERIER_TEST_INFLUXDB_IOX_INGESTER_ADDRESSES", + required = false, + num_args=0.., + value_delimiter = ',' + )] + pub(crate) ingester_addresses: Vec, + } + + #[test] + fn optional_and_specified_via_environment_variable() { + env::set_var( + "NO_CONFLICT_QUERIER_TEST_INFLUXDB_IOX_INGESTER_ADDRESSES", + "http://example.com:1234,somescheme://0.0.0.0:1000,127.0.0.1:8080", + ); + let args = ["my_binary"]; + let expected = [ + "http://example.com:1234/", + "somescheme://0.0.0.0:1000/", + "http://127.0.0.1:8080/", + ]; + + let querier = EnvQuerierConfig::try_parse_from(args).unwrap(); + let actual: Vec<_> = querier + .ingester_addresses + .iter() + .map(ToString::to_string) + .collect(); + assert_eq!(actual, expected); + } +} diff --git a/clap_blocks/src/lib.rs b/clap_blocks/src/lib.rs new file mode 100644 index 0000000..d9f6891 --- /dev/null +++ b/clap_blocks/src/lib.rs @@ -0,0 +1,37 @@ +//! Building blocks for [`clap`]-driven configs. +//! +//! They can easily be re-used using `#[clap(flatten)]`. +#![deny(rustdoc::broken_intra_doc_links, rust_2018_idioms)] +#![warn( + missing_copy_implementations, + missing_docs, + clippy::explicit_iter_loop, + // See https://github.com/influxdata/influxdb_iox/pull/1671 + clippy::future_not_send, + clippy::use_self, + clippy::clone_on_ref_ptr, + clippy::todo, + clippy::dbg_macro, + unused_crate_dependencies +)] + +// Workaround for "unused crate" lint false positives. 
+use workspace_hack as _; + +pub mod bulk_ingest; +pub mod catalog_cache; +pub mod catalog_dsn; +pub mod compactor; +pub mod compactor_scheduler; +pub mod garbage_collector; +pub mod gossip; +pub mod ingester; +pub mod ingester_address; +pub mod memory_size; +pub mod object_store; +pub mod parquet_cache; +pub mod querier; +pub mod router; +pub mod run_config; +pub mod single_tenant; +pub mod socket_addr; diff --git a/clap_blocks/src/memory_size.rs b/clap_blocks/src/memory_size.rs new file mode 100644 index 0000000..6e7515d --- /dev/null +++ b/clap_blocks/src/memory_size.rs @@ -0,0 +1,113 @@ +//! Helper types to express memory size. + +use std::{str::FromStr, sync::OnceLock}; + +use sysinfo::{MemoryRefreshKind, RefreshKind, System}; + +/// Memory size. +/// +/// # Parsing +/// This can be parsed from strings in one of the following formats: +/// +/// - **absolute:** just use a non-negative number to specify the absolute bytes, e.g. `1024` +/// - **relative:** use percentage between 0 and 100 (both inclusive) to specify a relative amount of the totally +/// available memory size, e.g. `50%` +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct MemorySize(usize); + +impl MemorySize { + /// Number of bytes. 
+ pub fn bytes(&self) -> usize { + self.0 + } +} + +impl std::fmt::Debug for MemorySize { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl std::fmt::Display for MemorySize { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl FromStr for MemorySize { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.strip_suffix('%') { + Some(s) => { + let percentage = u64::from_str(s).map_err(|e| e.to_string())?; + if percentage > 100 { + return Err(format!( + "relative memory size must be in [0, 100] but is {percentage}" + )); + } + let total = total_mem_bytes(); + let bytes = (percentage as f64 / 100f64 * total as f64).round() as usize; + Ok(Self(bytes)) + } + None => { + let bytes = usize::from_str(s).map_err(|e| e.to_string())?; + Ok(Self(bytes)) + } + } + } +} + +/// Totally available memory size in bytes. +pub fn total_mem_bytes() -> usize { + // Keep this in a global state so that we only need to inspect the system once during IOx startup. 
+ static TOTAL_MEM_BYTES: OnceLock = OnceLock::new(); + + *TOTAL_MEM_BYTES.get_or_init(|| { + let sys = System::new_with_specifics( + RefreshKind::new().with_memory(MemoryRefreshKind::everything()), + ); + sys.total_memory() as usize + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse() { + assert_ok("0", 0); + assert_ok("1", 1); + assert_ok("1024", 1024); + assert_ok("0%", 0); + + assert_gt_zero("50%"); + + assert_err("-1", "invalid digit found in string"); + assert_err("foo", "invalid digit found in string"); + assert_err("-1%", "invalid digit found in string"); + assert_err( + "101%", + "relative memory size must be in [0, 100] but is 101", + ); + } + + #[track_caller] + fn assert_ok(s: &'static str, expected: usize) { + let parsed: MemorySize = s.parse().unwrap(); + assert_eq!(parsed.bytes(), expected); + } + + #[track_caller] + fn assert_gt_zero(s: &'static str) { + let parsed: MemorySize = s.parse().unwrap(); + assert!(parsed.bytes() > 0); + } + + #[track_caller] + fn assert_err(s: &'static str, expected: &'static str) { + let err = MemorySize::from_str(s).unwrap_err(); + assert_eq!(err, expected); + } +} diff --git a/clap_blocks/src/object_store.rs b/clap_blocks/src/object_store.rs new file mode 100644 index 0000000..e961357 --- /dev/null +++ b/clap_blocks/src/object_store.rs @@ -0,0 +1,775 @@ +//! CLI handling for object store config (via CLI arguments and environment variables). 
+ +use futures::TryStreamExt; +use non_empty_string::NonEmptyString; +use object_store::{ + memory::InMemory, + path::Path, + throttle::{ThrottleConfig, ThrottledStore}, + DynObjectStore, +}; +use observability_deps::tracing::{info, warn}; +use snafu::{ResultExt, Snafu}; +use std::{convert::Infallible, fs, num::NonZeroUsize, path::PathBuf, sync::Arc, time::Duration}; +use uuid::Uuid; + +use crate::parquet_cache::ParquetCacheClientConfig; + +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +pub enum ParseError { + #[snafu(display("Unable to create database directory {:?}: {}", path, source))] + CreatingDatabaseDirectory { + path: PathBuf, + source: std::io::Error, + }, + + #[snafu(display("Unable to create local store {:?}: {}", path, source))] + CreateLocalFileSystem { + path: PathBuf, + source: object_store::Error, + }, + + #[snafu(display( + "Specified {:?} for the object store, required configuration missing for {}", + object_store, + missing + ))] + MissingObjectStoreConfig { + object_store: ObjectStoreType, + missing: String, + }, + + // Creating a new S3 object store can fail if the region is *specified* but + // not *parseable* as a rusoto `Region`. The other object store constructors + // don't return `Result`. + #[snafu(display("Error configuring Amazon S3: {}", source))] + InvalidS3Config { source: object_store::Error }, + + #[snafu(display("Error configuring GCS: {}", source))] + InvalidGCSConfig { source: object_store::Error }, + + #[snafu(display("Error configuring Microsoft Azure: {}", source))] + InvalidAzureConfig { source: object_store::Error }, +} + +/// The AWS region to use for Amazon S3 based object storage if none is +/// specified. +pub const FALLBACK_AWS_REGION: &str = "us-east-1"; + +/// A `clap` `value_parser` which returns `None` when given an empty string and +/// `Some(NonEmptyString)` otherwise. 
+fn parse_optional_string(s: &str) -> Result, Infallible> { + Ok(NonEmptyString::new(s.to_string()).ok()) +} + +/// CLI config for object stores. +#[derive(Debug, Clone, clap::Parser)] +pub struct ObjectStoreConfig { + /// Which object storage to use. If not specified, defaults to memory. + /// + /// Possible values (case insensitive): + /// + /// * memory (default): Effectively no object persistence. + /// * memorythrottled: Like `memory` but with latency and throughput that somewhat resamble a cloud + /// object store. Useful for testing and benchmarking. + /// * file: Stores objects in the local filesystem. Must also set `--data-dir`. + /// * s3: Amazon S3. Must also set `--bucket`, `--aws-access-key-id`, `--aws-secret-access-key`, and + /// possibly `--aws-default-region`. + /// * google: Google Cloud Storage. Must also set `--bucket` and `--google-service-account`. + /// * azure: Microsoft Azure blob storage. Must also set `--bucket`, `--azure-storage-account`, + /// and `--azure-storage-access-key`. + #[clap( + value_enum, + long = "object-store", + env = "INFLUXDB_IOX_OBJECT_STORE", + ignore_case = true, + action, + verbatim_doc_comment + )] + pub object_store: Option, + + /// Name of the bucket to use for the object store. Must also set + /// `--object-store` to a cloud object storage to have any effect. + /// + /// If using Google Cloud Storage for the object store, this item as well + /// as `--google-service-account` must be set. + /// + /// If using S3 for the object store, must set this item as well + /// as `--aws-access-key-id` and `--aws-secret-access-key`. Can also set + /// `--aws-default-region` if not using the fallback region. + /// + /// If using Azure for the object store, set this item to the name of a + /// container you've created in the associated storage account, under + /// Blob Service > Containers. Must also set `--azure-storage-account` and + /// `--azure-storage-access-key`. 
+ #[clap(long = "bucket", env = "INFLUXDB_IOX_BUCKET", action)] + pub bucket: Option, + + /// The location InfluxDB IOx will use to store files locally. + #[clap(long = "data-dir", env = "INFLUXDB_IOX_DB_DIR", action)] + pub database_directory: Option, + + /// When using Amazon S3 as the object store, set this to an access key that + /// has permission to read from and write to the specified S3 bucket. + /// + /// Must also set `--object-store=s3`, `--bucket`, and + /// `--aws-secret-access-key`. Can also set `--aws-default-region` if not + /// using the fallback region. + /// + /// Prefer the environment variable over the command line flag in shared + /// environments. + /// + /// An empty string value is equivalent to omitting the flag. + /// Note: must refer to std::option::Option explicitly, see + #[clap(long = "aws-access-key-id", env = "AWS_ACCESS_KEY_ID", value_parser = parse_optional_string, default_value="", action)] + pub aws_access_key_id: std::option::Option, + + /// When using Amazon S3 as the object store, set this to the secret access + /// key that goes with the specified access key ID. + /// + /// Must also set `--object-store=s3`, `--bucket`, `--aws-access-key-id`. + /// Can also set `--aws-default-region` if not using the fallback region. + /// + /// Prefer the environment variable over the command line flag in shared + /// environments. + /// + /// An empty string value is equivalent to omitting the flag. + /// Note: must refer to std::option::Option explicitly, see + #[clap(long = "aws-secret-access-key", env = "AWS_SECRET_ACCESS_KEY", value_parser = parse_optional_string, default_value = "", action)] + pub aws_secret_access_key: std::option::Option, + + /// When using Amazon S3 as the object store, set this to the region + /// that goes with the specified bucket if different from the fallback + /// value. + /// + /// Must also set `--object-store=s3`, `--bucket`, `--aws-access-key-id`, + /// and `--aws-secret-access-key`. 
+ #[clap( + long = "aws-default-region", + env = "AWS_DEFAULT_REGION", + default_value = FALLBACK_AWS_REGION, + action, + )] + pub aws_default_region: String, + + /// When using Amazon S3 compatibility storage service, set this to the + /// endpoint. + /// + /// Must also set `--object-store=s3`, `--bucket`. Can also set `--aws-default-region` + /// if not using the fallback region. + /// + /// Prefer the environment variable over the command line flag in shared + /// environments. + #[clap(long = "aws-endpoint", env = "AWS_ENDPOINT", action)] + pub aws_endpoint: Option, + + /// When using Amazon S3 as an object store, set this to the session token. This is handy when using a federated + /// login / SSO and you fetch credentials via the UI. + /// + /// Is it assumed that the session is valid as long as the IOx server is running. + /// + /// Prefer the environment variable over the command line flag in shared + /// environments. + #[clap(long = "aws-session-token", env = "AWS_SESSION_TOKEN", action)] + pub aws_session_token: Option, + + /// Allow unencrypted HTTP connection to AWS. + #[clap(long = "aws-allow-http", env = "AWS_ALLOW_HTTP", action)] + pub aws_allow_http: bool, + + /// When using Google Cloud Storage as the object store, set this to the + /// path to the JSON file that contains the Google credentials. + /// + /// Must also set `--object-store=google` and `--bucket`. + #[clap( + long = "google-service-account", + env = "GOOGLE_SERVICE_ACCOUNT", + action + )] + pub google_service_account: Option, + + /// When using Microsoft Azure as the object store, set this to the + /// name you see when going to All Services > Storage accounts > `[name]`. + /// + /// Must also set `--object-store=azure`, `--bucket`, and + /// `--azure-storage-access-key`. 
+ #[clap(long = "azure-storage-account", env = "AZURE_STORAGE_ACCOUNT", action)] + pub azure_storage_account: Option, + + /// When using Microsoft Azure as the object store, set this to one of the + /// Key values in the Storage account's Settings > Access keys. + /// + /// Must also set `--object-store=azure`, `--bucket`, and + /// `--azure-storage-account`. + /// + /// Prefer the environment variable over the command line flag in shared + /// environments. + #[clap( + long = "azure-storage-access-key", + env = "AZURE_STORAGE_ACCESS_KEY", + action + )] + pub azure_storage_access_key: Option, + + /// When using a network-based object store, limit the number of connection to this value. + #[clap( + long = "object-store-connection-limit", + env = "OBJECT_STORE_CONNECTION_LIMIT", + default_value = "16", + action + )] + pub object_store_connection_limit: NonZeroUsize, + + /// Optional config for the cache client. + #[clap(flatten)] + pub cache_config: Option, +} + +impl ObjectStoreConfig { + /// Create a new instance for all-in-one mode, only allowing some arguments. + pub fn new(database_directory: Option) -> Self { + match &database_directory { + Some(dir) => info!("Object store: File-based in `{}`", dir.display()), + None => info!("Object store: In-memory"), + } + + let object_store = database_directory.as_ref().map(|_| ObjectStoreType::File); + + Self { + aws_access_key_id: Default::default(), + aws_allow_http: Default::default(), + aws_default_region: Default::default(), + aws_endpoint: Default::default(), + aws_secret_access_key: Default::default(), + aws_session_token: Default::default(), + azure_storage_access_key: Default::default(), + azure_storage_account: Default::default(), + bucket: Default::default(), + database_directory, + google_service_account: Default::default(), + object_store, + object_store_connection_limit: NonZeroUsize::new(16).unwrap(), + cache_config: Default::default(), + } + } +} + +/// Object-store type. 
+#[derive(Debug, Copy, Clone, PartialEq, Eq, clap::ValueEnum)] +pub enum ObjectStoreType { + /// In-memory. + Memory, + + /// In-memory with additional throttling applied for testing + MemoryThrottled, + + /// Filesystem. + File, + + /// AWS S3. + S3, + + /// GCS. + Google, + + /// Azure object store. + Azure, +} + +#[cfg(feature = "gcp")] +fn new_gcs(config: &ObjectStoreConfig) -> Result, ParseError> { + use object_store::gcp::GoogleCloudStorageBuilder; + use object_store::limit::LimitStore; + + info!(bucket=?config.bucket, object_store_type="GCS", "Object Store"); + + let mut builder = GoogleCloudStorageBuilder::new(); + + if let Some(bucket) = &config.bucket { + builder = builder.with_bucket_name(bucket); + } + if let Some(account) = &config.google_service_account { + builder = builder.with_service_account_path(account); + } + + Ok(Arc::new(LimitStore::new( + builder.build().context(InvalidGCSConfigSnafu)?, + config.object_store_connection_limit.get(), + ))) +} + +#[cfg(not(feature = "gcp"))] +fn new_gcs(_: &ObjectStoreConfig) -> Result, ParseError> { + panic!("GCS support not enabled, recompile with the gcp feature enabled") +} + +#[cfg(feature = "aws")] +fn new_s3(config: &ObjectStoreConfig) -> Result, ParseError> { + use object_store::limit::LimitStore; + + info!( + bucket=?config.bucket, + endpoint=?config.aws_endpoint, + object_store_type="S3", + "Object Store" + ); + + Ok(Arc::new(LimitStore::new( + build_s3(config)?, + config.object_store_connection_limit.get(), + ))) +} + +#[cfg(feature = "aws")] +fn build_s3(config: &ObjectStoreConfig) -> Result { + use object_store::aws::AmazonS3Builder; + + let mut builder = AmazonS3Builder::new() + .with_allow_http(config.aws_allow_http) + .with_region(&config.aws_default_region) + .with_imdsv1_fallback(); + + if let Some(bucket) = &config.bucket { + builder = builder.with_bucket_name(bucket); + } + if let Some(key_id) = &config.aws_access_key_id { + builder = builder.with_access_key_id(key_id.get()); + } + if let 
Some(token) = &config.aws_session_token { + builder = builder.with_token(token); + } + if let Some(secret) = &config.aws_secret_access_key { + builder = builder.with_secret_access_key(secret.get()); + } + if let Some(endpoint) = &config.aws_endpoint { + builder = builder.with_endpoint(endpoint); + } + + builder.build().context(InvalidS3ConfigSnafu) +} + +#[cfg(not(feature = "aws"))] +fn new_s3(_: &ObjectStoreConfig) -> Result, ParseError> { + panic!("S3 support not enabled, recompile with the aws feature enabled") +} + +#[cfg(feature = "azure")] +fn new_azure(config: &ObjectStoreConfig) -> Result, ParseError> { + use object_store::azure::MicrosoftAzureBuilder; + use object_store::limit::LimitStore; + + info!(bucket=?config.bucket, account=?config.azure_storage_account, + object_store_type="Azure", "Object Store"); + + let mut builder = MicrosoftAzureBuilder::new(); + + if let Some(bucket) = &config.bucket { + builder = builder.with_container_name(bucket); + } + if let Some(account) = &config.azure_storage_account { + builder = builder.with_account(account) + } + if let Some(key) = &config.azure_storage_access_key { + builder = builder.with_access_key(key) + } + + Ok(Arc::new(LimitStore::new( + builder.build().context(InvalidAzureConfigSnafu)?, + config.object_store_connection_limit.get(), + ))) +} + +#[cfg(not(feature = "azure"))] +fn new_azure(_: &ObjectStoreConfig) -> Result, ParseError> { + panic!("Azure blob storage support not enabled, recompile with the azure feature enabled") +} + +/// Create config-dependant object store. +pub fn make_object_store(config: &ObjectStoreConfig) -> Result, ParseError> { + if let Some(data_dir) = &config.database_directory { + if !matches!(&config.object_store, Some(ObjectStoreType::File)) { + warn!(?data_dir, object_store_type=?config.object_store, + "--data-dir / `INFLUXDB_IOX_DB_DIR` ignored. 
It only affects 'file' object stores"); + } + } + + let remote_store: Arc = match &config.object_store { + Some(ObjectStoreType::Memory) | None => { + info!(object_store_type = "Memory", "Object Store"); + Arc::new(InMemory::new()) + } + Some(ObjectStoreType::MemoryThrottled) => { + let config = ThrottleConfig { + // for every call: assume a 100ms latency + wait_delete_per_call: Duration::from_millis(100), + wait_get_per_call: Duration::from_millis(100), + wait_list_per_call: Duration::from_millis(100), + wait_list_with_delimiter_per_call: Duration::from_millis(100), + wait_put_per_call: Duration::from_millis(100), + + // for list operations: assume we need 1 call per 1k entries at 100ms + wait_list_per_entry: Duration::from_millis(100) / 1_000, + wait_list_with_delimiter_per_entry: Duration::from_millis(100) / 1_000, + + // for upload/download: assume 1GByte/s + wait_get_per_byte: Duration::from_secs(1) / 1_000_000_000, + }; + + info!(?config, object_store_type = "Memory", "Object Store"); + Arc::new(ThrottledStore::new(InMemory::new(), config)) + } + + Some(ObjectStoreType::Google) => new_gcs(config)?, + Some(ObjectStoreType::S3) => new_s3(config)?, + Some(ObjectStoreType::Azure) => new_azure(config)?, + Some(ObjectStoreType::File) => match config.database_directory.as_ref() { + Some(db_dir) => { + info!(?db_dir, object_store_type = "Directory", "Object Store"); + fs::create_dir_all(db_dir) + .context(CreatingDatabaseDirectorySnafu { path: db_dir })?; + + let store = object_store::local::LocalFileSystem::new_with_prefix(db_dir) + .context(CreateLocalFileSystemSnafu { path: db_dir })?; + Arc::new(store) + } + None => MissingObjectStoreConfigSnafu { + object_store: ObjectStoreType::File, + missing: "data-dir", + } + .fail()?, + }, + }; + + if let Some(cache_config) = &config.cache_config { + let cache = parquet_cache::make_client( + cache_config.namespace_addr.clone(), + Arc::clone(&remote_store), + ); + info!(?cache_config, "Parquet cache enabled"); + Ok(cache) + 
} else { + Ok(remote_store) + } +} + +/// The `object_store::signer::Signer` trait is only implemented for AWS currently, so when the AWS +/// feature is enabled and the configured object store is S3, return a signer. +#[cfg(feature = "aws")] +pub fn make_presigned_url_signer( + config: &ObjectStoreConfig, +) -> Result>, ParseError> { + match &config.object_store { + Some(ObjectStoreType::S3) => Ok(Some(Arc::new(build_s3(config)?))), + _ => Ok(None), + } +} + +/// The `object_store::signer::Signer` trait is only implemented for AWS currently, so if the AWS +/// feature isn't enabled, don't return a signer. +#[cfg(not(feature = "aws"))] +pub fn make_presigned_url_signer( + _config: &ObjectStoreConfig, +) -> Result>, ParseError> { + Ok(None) +} + +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +pub enum CheckError { + #[snafu(display("Cannot read from object store: {}", source))] + CannotReadObjectStore { source: object_store::Error }, +} + +/// Check if object store is properly configured and accepts writes and reads. +/// +/// Note: This does NOT test if the object store is writable! +pub async fn check_object_store(object_store: &DynObjectStore) -> Result<(), CheckError> { + // Use some prefix that will very likely end in an empty result, so we don't pull too much actual data here. + let uuid = Uuid::new_v4().to_string(); + let prefix = Path::from_iter([uuid]); + + // create stream (this might fail if the store is not readable) + let mut stream = object_store.list(Some(&prefix)); + + // ... 
but sometimes it fails only if we use the resulting stream, so try that once + stream + .try_next() + .await + .context(CannotReadObjectStoreSnafu)?; + + // store seems to be readable + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use clap::Parser; + use std::env; + use tempfile::TempDir; + + #[test] + fn default_object_store_is_memory() { + let config = ObjectStoreConfig::try_parse_from(["server"]).unwrap(); + + let object_store = make_object_store(&config).unwrap(); + assert_eq!(&object_store.to_string(), "InMemory") + } + + #[test] + fn explicitly_set_object_store_to_memory() { + let config = + ObjectStoreConfig::try_parse_from(["server", "--object-store", "memory"]).unwrap(); + + let object_store = make_object_store(&config).unwrap(); + assert_eq!(&object_store.to_string(), "InMemory") + } + + #[test] + fn default_url_signer_is_none() { + let config = ObjectStoreConfig::try_parse_from(["server"]).unwrap(); + + let signer = make_presigned_url_signer(&config).unwrap(); + assert!(signer.is_none(), "Expected None, got {signer:?}"); + } + + #[test] + #[cfg(feature = "aws")] + fn valid_s3_config() { + let config = ObjectStoreConfig::try_parse_from([ + "server", + "--object-store", + "s3", + "--bucket", + "mybucket", + "--aws-access-key-id", + "NotARealAWSAccessKey", + "--aws-secret-access-key", + "NotARealAWSSecretAccessKey", + ]) + .unwrap(); + + let object_store = make_object_store(&config).unwrap(); + assert_eq!( + &object_store.to_string(), + "LimitStore(16, AmazonS3(mybucket))" + ) + } + + #[test] + #[cfg(feature = "aws")] + fn s3_config_missing_params() { + let mut config = + ObjectStoreConfig::try_parse_from(["server", "--object-store", "s3"]).unwrap(); + + // clean out eventual leaks via env variables + config.bucket = None; + + let err = make_object_store(&config).unwrap_err().to_string(); + + assert_eq!( + err, + "Error configuring Amazon S3: Generic S3 error: Missing bucket name" + ); + } + + #[test] + #[cfg(feature = "aws")] + fn 
valid_s3_url_signer() { + let config = ObjectStoreConfig::try_parse_from([ + "server", + "--object-store", + "s3", + "--bucket", + "mybucket", + "--aws-access-key-id", + "NotARealAWSAccessKey", + "--aws-secret-access-key", + "NotARealAWSSecretAccessKey", + ]) + .unwrap(); + + assert!(make_presigned_url_signer(&config).unwrap().is_some()); + + // Even with the aws feature on, any other object store shouldn't create a signer. + let root = TempDir::new().unwrap(); + let root_path = root.path().to_str().unwrap(); + + let config = ObjectStoreConfig::try_parse_from([ + "server", + "--object-store", + "file", + "--data-dir", + root_path, + ]) + .unwrap(); + + let signer = make_presigned_url_signer(&config).unwrap(); + assert!(signer.is_none(), "Expected None, got {signer:?}"); + } + + #[test] + #[cfg(feature = "aws")] + fn s3_url_signer_config_missing_params() { + let mut config = + ObjectStoreConfig::try_parse_from(["server", "--object-store", "s3"]).unwrap(); + + // clean out eventual leaks via env variables + config.bucket = None; + + let err = make_presigned_url_signer(&config).unwrap_err().to_string(); + + assert_eq!( + err, + "Error configuring Amazon S3: Generic S3 error: Missing bucket name" + ); + } + + #[test] + #[cfg(feature = "gcp")] + fn valid_google_config() { + use std::io::Write; + use tempfile::NamedTempFile; + + let mut file = NamedTempFile::new().expect("tempfile should be created"); + const FAKE_KEY: &str = r#"{"private_key": "private_key", "private_key_id": "private_key_id", "client_email":"client_email", "disable_oauth":true}"#; + writeln!(file, "{FAKE_KEY}").unwrap(); + let path = file.path().to_str().expect("file path should exist"); + + let config = ObjectStoreConfig::try_parse_from([ + "server", + "--object-store", + "google", + "--bucket", + "mybucket", + "--google-service-account", + path, + ]) + .unwrap(); + + let object_store = make_object_store(&config).unwrap(); + assert_eq!( + &object_store.to_string(), + "LimitStore(16, 
GoogleCloudStorage(mybucket))" + ) + } + + #[test] + #[cfg(feature = "gcp")] + fn google_config_missing_params() { + let mut config = + ObjectStoreConfig::try_parse_from(["server", "--object-store", "google"]).unwrap(); + + // clean out eventual leaks via env variables + config.bucket = None; + + let err = make_object_store(&config).unwrap_err().to_string(); + + assert_eq!( + err, + "Error configuring GCS: Generic GCS error: Missing bucket name" + ); + } + + #[test] + #[cfg(feature = "azure")] + fn valid_azure_config() { + let config = ObjectStoreConfig::try_parse_from([ + "server", + "--object-store", + "azure", + "--bucket", + "mybucket", + "--azure-storage-account", + "NotARealStorageAccount", + "--azure-storage-access-key", + "Zm9vYmFy", // base64 encoded "foobar" + ]) + .unwrap(); + + let object_store = make_object_store(&config).unwrap(); + assert_eq!(&object_store.to_string(), "LimitStore(16, MicrosoftAzure { account: NotARealStorageAccount, container: mybucket })") + } + + #[test] + #[cfg(feature = "azure")] + fn azure_config_missing_params() { + let mut config = + ObjectStoreConfig::try_parse_from(["server", "--object-store", "azure"]).unwrap(); + + // clean out eventual leaks via env variables + config.bucket = None; + + let err = make_object_store(&config).unwrap_err().to_string(); + + assert_eq!( + err, + "Error configuring Microsoft Azure: Generic MicrosoftAzure error: Container name must be specified" + ); + } + + #[test] + fn valid_file_config() { + let root = TempDir::new().unwrap(); + let root_path = root.path().to_str().unwrap(); + + let config = ObjectStoreConfig::try_parse_from([ + "server", + "--object-store", + "file", + "--data-dir", + root_path, + ]) + .unwrap(); + + let object_store = make_object_store(&config).unwrap().to_string(); + assert!( + object_store.starts_with("LocalFileSystem"), + "{}", + object_store + ) + } + + #[test] + fn file_config_missing_params() { + // this test tests for failure to configure the object store because of 
data-dir configuration missing + // if the INFLUXDB_IOX_DB_DIR env variable is set, the test fails because the configuration is + // actually present. + env::remove_var("INFLUXDB_IOX_DB_DIR"); + let config = + ObjectStoreConfig::try_parse_from(["server", "--object-store", "file"]).unwrap(); + + let err = make_object_store(&config).unwrap_err().to_string(); + + assert_eq!( + err, + "Specified File for the object store, required configuration missing for \ + data-dir" + ); + } + + #[test] + fn valid_cache_config() { + let root = TempDir::new().unwrap(); + let root_path = root.path().to_str().unwrap(); + + let config = ObjectStoreConfig::try_parse_from([ + "server", + "--object-store", + "file", + "--data-dir", + root_path, + "--parquet-cache-namespace-addr", + "http://k8s-noninstance-general-service-route:8080", + ]) + .unwrap(); + + let object_store = make_object_store(&config).unwrap().to_string(); + assert!( + object_store.starts_with("DataCacheObjectStore"), + "{}", + object_store + ) + } +} diff --git a/clap_blocks/src/parquet_cache.rs b/clap_blocks/src/parquet_cache.rs new file mode 100644 index 0000000..d93aa94 --- /dev/null +++ b/clap_blocks/src/parquet_cache.rs @@ -0,0 +1,57 @@ +//! CLI handling for parquet data cache config (via CLI arguments and environment variables). + +/// Config for cache client. +#[derive(Debug, Clone, Default, clap::Parser)] +pub struct ParquetCacheClientConfig { + /// The address for the service namespace (not a given instance). + /// + /// When the client comes online, it discovers the keyspace + /// by issue requests to this address. + #[clap( + long = "parquet-cache-namespace-addr", + env = "INFLUXDB_IOX_PARQUET_CACHE_NAMESPACE_ADDR", + required = false + )] + pub namespace_addr: String, +} + +/// Config for cache instance. +#[derive(Debug, Clone, Default, clap::Parser)] +pub struct ParquetCacheInstanceConfig { + /// The path to the config file for the keyspace. 
+ #[clap( + long = "parquet-cache-keyspace-config-path", + env = "INFLUXDB_IOX_PARQUET_CACHE_KEYSPACE_CONFIG_PATH", + required = true + )] + pub keyspace_config_path: String, + + /// The hostname of the cache instance (k8s pod) running this process. + /// + /// Cache controller should be setting this env var. + #[clap( + long = "parquet-cache-instance-hostname", + env = "HOSTNAME", + required = true + )] + pub instance_hostname: String, + + /// The local directory to store data. + #[clap( + long = "parquet-cache-local-dir", + env = "INFLUXDB_IOX_PARQUET_CACHE_LOCAL_DIR", + required = true + )] + pub local_dir: String, +} + +impl From for parquet_cache::ParquetCacheServerConfig { + fn from(instance_config: ParquetCacheInstanceConfig) -> Self { + Self { + keyspace_config_path: instance_config.keyspace_config_path, + hostname: instance_config.instance_hostname, + local_dir: instance_config.local_dir, + policy_config: Default::default(), + } + } +} diff --git a/clap_blocks/src/querier.rs b/clap_blocks/src/querier.rs new file mode 100644 index 0000000..4a62455 --- /dev/null +++ b/clap_blocks/src/querier.rs @@ -0,0 +1,264 @@ +//! Querier-related configs. + +use crate::{ + ingester_address::IngesterAddress, + memory_size::MemorySize, + single_tenant::{CONFIG_AUTHZ_ENV_NAME, CONFIG_AUTHZ_FLAG}, +}; +use std::{collections::HashMap, num::NonZeroUsize}; + +/// CLI config for querier configuration +#[derive(Debug, Clone, PartialEq, Eq, clap::Parser)] +pub struct QuerierConfig { + /// Addr for connection to authz + #[clap(long = CONFIG_AUTHZ_FLAG, env = CONFIG_AUTHZ_ENV_NAME)] + pub authz_address: Option, + + /// The number of threads to use for queries. + /// + /// If not specified, defaults to the number of cores on the system + #[clap( + long = "num-query-threads", + env = "INFLUXDB_IOX_NUM_QUERY_THREADS", + action + )] + pub num_query_threads: Option, + + /// Size of memory pool used during query exec, in bytes. 
+ /// + /// If queries attempt to allocate more than this many bytes + /// during execution, they will error with "ResourcesExhausted". + /// + /// Can be given as absolute value or in percentage of the total available memory (e.g. `10%`). + #[clap( + long = "exec-mem-pool-bytes", + env = "INFLUXDB_IOX_EXEC_MEM_POOL_BYTES", + default_value = "8589934592", // 8GB + action + )] + pub exec_mem_pool_bytes: MemorySize, + + /// gRPC address for the router to talk with the ingesters. For + /// example: + /// + /// "http://127.0.0.1:8083" + /// + /// or + /// + /// "http://10.10.10.1:8083,http://10.10.10.2:8083" + /// + /// for multiple addresses. + #[clap( + long = "ingester-addresses", + env = "INFLUXDB_IOX_INGESTER_ADDRESSES", + required = false, + num_args = 0.., + value_delimiter = ',' + )] + pub ingester_addresses: Vec, + + /// Size of the RAM cache used to store catalog metadata information in bytes. + /// + /// Can be given as absolute value or in percentage of the total available memory (e.g. `10%`). + #[clap( + long = "ram-pool-metadata-bytes", + env = "INFLUXDB_IOX_RAM_POOL_METADATA_BYTES", + default_value = "134217728", // 128MB + action + )] + pub ram_pool_metadata_bytes: MemorySize, + + /// Size of the RAM cache used to store data in bytes. + /// + /// Can be given as absolute value or in percentage of the total available memory (e.g. `10%`). + #[clap( + long = "ram-pool-data-bytes", + env = "INFLUXDB_IOX_RAM_POOL_DATA_BYTES", + default_value = "1073741824", // 1GB + action + )] + pub ram_pool_data_bytes: MemorySize, + + /// Limit the number of concurrent queries. + #[clap( + long = "max-concurrent-queries", + env = "INFLUXDB_IOX_MAX_CONCURRENT_QUERIES", + default_value = "10", + action + )] + pub max_concurrent_queries: usize, + + /// After how many ingester query errors should the querier enter circuit breaker mode? + /// + /// The querier normally contacts the ingester for any unpersisted data during query planning. 
+ /// However, when the ingester can not be contacted for some reason, the querier will begin + /// returning results that do not include unpersisted data and enter "circuit breaker mode" + /// to avoid continually retrying the failing connection on subsequent queries. + /// + /// If circuits are open, the querier will NOT contact the ingester and no unpersisted data + /// will be presented to the user. + /// + /// Circuits will switch to "half open" after some jittered timeout and the querier will try to + /// use the ingester in question again. If this succeeds, we are back to normal, otherwise it + /// will back off exponentially before trying again (and again ...). + /// + /// In a production environment the `ingester_circuit_state` metric should be monitored. + #[clap( + long = "ingester-circuit-breaker-threshold", + env = "INFLUXDB_IOX_INGESTER_CIRCUIT_BREAKER_THRESHOLD", + default_value = "10", + action + )] + pub ingester_circuit_breaker_threshold: u64, + + /// DataFusion config. + #[clap( + long = "datafusion-config", + env = "INFLUXDB_IOX_DATAFUSION_CONFIG", + default_value = "", + value_parser = parse_datafusion_config, + action + )] + pub datafusion_config: HashMap, + + /// Use the new V2 API to talk to the ingester. + /// + /// Defaults to "no". + /// + /// See . 
/// Parse a comma-separated list of `KEY:VALUE` pairs into a map.
///
/// Whitespace around the whole string, around each pair, and around each
/// key and value is trimmed. Only the FIRST `:` in a pair separates key
/// from value, so values may themselves contain colons (e.g. `x:y:z`
/// parses as key `x`, value `y:z`). An empty (or all-whitespace) input
/// yields an empty map.
///
/// # Errors
///
/// Returns an error if a pair contains no `:` separator, or if the same
/// key appears more than once.
fn parse_datafusion_config(
    s: &str,
) -> Result<HashMap<String, String>, Box<dyn std::error::Error + Send + Sync + 'static>> {
    let s = s.trim();
    if s.is_empty() {
        return Ok(HashMap::with_capacity(0));
    }

    let mut out = HashMap::new();
    for part in s.split(',') {
        // split_once splits on the first ':' only, replacing the previous
        // splitn(2, ':') + Vec collect (one fewer allocation per pair).
        // NOTE: the "invalid pair" message intentionally echoes the whole
        // input string `s`, matching the established CLI error text.
        let (key, value) = part
            .trim()
            .split_once(':')
            .ok_or_else(|| format!("Invalid key value pair - expected 'KEY:VALUE' got '{s}'"))?;
        let key = key.trim().to_owned();
        if out.insert(key.clone(), value.trim().to_owned()).is_some() {
            return Err(format!("key '{key}' passed multiple times").into());
        }
    }

    Ok(out)
}
.to_string(); + + assert_contains!( + actual, + "error: \ + invalid value '\\ingester-0:8082' \ + for '--ingester-addresses [...]': \ + invalid uri character" + ); + } + + #[test] + fn test_datafusion_config() { + let actual = QuerierConfig::try_parse_from([ + "my_binary", + "--datafusion-config= foo : bar , x:y:z ", + ]) + .unwrap(); + + assert_eq!( + actual.datafusion_config, + HashMap::from([ + (String::from("foo"), String::from("bar")), + (String::from("x"), String::from("y:z")), + ]), + ); + } + + #[test] + fn bad_datafusion_config() { + let actual = QuerierConfig::try_parse_from(["my_binary", "--datafusion-config=foo"]) + .unwrap_err() + .to_string(); + assert_contains!( + actual, + "error: invalid value 'foo' for '--datafusion-config ': Invalid key value pair - expected 'KEY:VALUE' got 'foo'" + ); + + let actual = + QuerierConfig::try_parse_from(["my_binary", "--datafusion-config=foo:bar,baz:1,foo:2"]) + .unwrap_err() + .to_string(); + assert_contains!( + actual, + "error: invalid value 'foo:bar,baz:1,foo:2' for '--datafusion-config ': key 'foo' passed multiple times" + ); + } +} diff --git a/clap_blocks/src/router.rs b/clap_blocks/src/router.rs new file mode 100644 index 0000000..28442d7 --- /dev/null +++ b/clap_blocks/src/router.rs @@ -0,0 +1,165 @@ +//! CLI config for the router using the RPC write path + +use crate::{ + bulk_ingest::BulkIngestConfig, + gossip::GossipConfig, + ingester_address::IngesterAddress, + single_tenant::{ + CONFIG_AUTHZ_ENV_NAME, CONFIG_AUTHZ_FLAG, CONFIG_CST_ENV_NAME, CONFIG_CST_FLAG, + }, +}; +use std::{ + num::{NonZeroUsize, ParseIntError}, + time::Duration, +}; + +/// CLI config for the router using the RPC write path +#[derive(Debug, Clone, clap::Parser)] +#[allow(missing_copy_implementations)] +pub struct RouterConfig { + /// Gossip config. + #[clap(flatten)] + pub gossip_config: GossipConfig, + + /// Bulk ingest API config. 
+ #[clap(flatten)] + pub bulk_ingest_config: BulkIngestConfig, + + /// Addr for connection to authz + #[clap( + long = CONFIG_AUTHZ_FLAG, + env = CONFIG_AUTHZ_ENV_NAME, + requires("single_tenant_deployment"), + )] + pub authz_address: Option, + + /// Differential handling based upon deployment to CST vs MT. + /// + /// At minimum, differs in supports of v1 endpoint. But also includes + /// differences in namespace handling, etc. + #[clap( + long = CONFIG_CST_FLAG, + env = CONFIG_CST_ENV_NAME, + default_value = "false", + requires_if("true", "authz_address") + )] + pub single_tenant_deployment: bool, + + /// The maximum number of simultaneous requests the HTTP server is + /// configured to accept. + /// + /// This number of requests, multiplied by the maximum request body size the + /// HTTP server is configured with gives the rough amount of memory a HTTP + /// server will use to buffer request bodies in memory. + /// + /// A default maximum of 200 requests, multiplied by the default 10MiB + /// maximum for HTTP request bodies == ~2GiB. + #[clap( + long = "max-http-requests", + env = "INFLUXDB_IOX_MAX_HTTP_REQUESTS", + default_value = "200", + action + )] + pub http_request_limit: usize, + + /// When writing line protocol data, does an error on a single line + /// reject the write? Or will all individual valid lines be written? + /// Set to true to enable all valid lines to write. + #[clap( + long = "partial-writes-enabled", + env = "INFLUXDB_IOX_PARTIAL_WRITES_ENABLED", + default_value = "false", + action + )] + pub permit_partial_writes: bool, + + /// gRPC address for the router to talk with the ingesters. For + /// example: + /// + /// "http://127.0.0.1:8083" + /// + /// or + /// + /// "http://10.10.10.1:8083,http://10.10.10.2:8083" + /// + /// for multiple addresses. 
+ #[clap( + long = "ingester-addresses", + env = "INFLUXDB_IOX_INGESTER_ADDRESSES", + required = true, + num_args=1.., + value_delimiter = ',' + )] + pub ingester_addresses: Vec, + + /// Retention period to use when auto-creating namespaces. + /// For infinite retention, leave this unset and it will default to `None`. + /// Setting it to zero will not make it infinite. + /// Ignored if namespace-autocreation-enabled is set to false. + #[clap( + long = "new-namespace-retention-hours", + env = "INFLUXDB_IOX_NEW_NAMESPACE_RETENTION_HOURS", + action + )] + pub new_namespace_retention_hours: Option, + + /// When writing data to a non-existent namespace, should the router auto-create the namespace + /// or reject the write? Set to false to disable namespace autocreation. + #[clap( + long = "namespace-autocreation-enabled", + env = "INFLUXDB_IOX_NAMESPACE_AUTOCREATION_ENABLED", + default_value = "true", + action + )] + pub namespace_autocreation_enabled: bool, + + /// Specify the timeout in seconds for a single RPC write request to an + /// ingester. + #[clap( + long = "rpc-write-timeout-seconds", + env = "INFLUXDB_IOX_RPC_WRITE_TIMEOUT_SECONDS", + default_value = "3", + value_parser = parse_duration + )] + pub rpc_write_timeout_seconds: Duration, + + /// Specify the maximum allowed outgoing RPC write message size when + /// communicating with the Ingester. + #[clap( + long = "rpc-write-max-outgoing-bytes", + env = "INFLUXDB_IOX_RPC_WRITE_MAX_OUTGOING_BYTES", + default_value = "104857600", // 100MiB + )] + pub rpc_write_max_outgoing_bytes: usize, + + /// Enable optional replication for each RPC write. + /// + /// This value specifies the total number of copies of data after + /// replication, defaulting to 1. + /// + /// If the desired replication level is not achieved, a partial write error + /// will be returned to the user. The write MAY be queryable after a partial + /// write failure. 
/// Interpret `input` as a whole (unsigned) number of seconds and convert
/// it into a [`Duration`].
///
/// Returns the underlying [`ParseIntError`] when `input` is not a valid
/// non-negative integer.
fn parse_duration(input: &str) -> Result<Duration, ParseIntError> {
    let seconds: u64 = input.parse()?;
    Ok(Duration::from_secs(seconds))
}
+ #[clap( + long = "grpc-bind", + env = "INFLUXDB_IOX_GRPC_BIND_ADDR", + default_value = DEFAULT_GRPC_BIND_ADDR, + action, + )] + pub grpc_bind_address: SocketAddr, + + /// Maximum size of HTTP requests. + #[clap( + long = "max-http-request-size", + env = "INFLUXDB_IOX_MAX_HTTP_REQUEST_SIZE", + default_value = "10485760", // 10 MiB + action, + )] + pub max_http_request_size: usize, + + /// object store config + #[clap(flatten)] + pub(crate) object_store_config: ObjectStoreConfig, +} + +impl RunConfig { + /// Get a reference to the run config's tracing config. + pub fn tracing_config(&self) -> &TracingConfig { + &self.tracing_config + } + + /// Get a reference to the run config's object store config. + pub fn object_store_config(&self) -> &ObjectStoreConfig { + &self.object_store_config + } + + /// Get a mutable reference to the run config's tracing config. + pub fn tracing_config_mut(&mut self) -> &mut TracingConfig { + &mut self.tracing_config + } + + /// Get a reference to the run config's logging config. + pub fn logging_config(&self) -> &LoggingConfig { + &self.logging_config + } + + /// set the http bind address + pub fn with_http_bind_address(mut self, http_bind_address: SocketAddr) -> Self { + self.http_bind_address = http_bind_address; + self + } + + /// set the grpc bind address + pub fn with_grpc_bind_address(mut self, grpc_bind_address: SocketAddr) -> Self { + self.grpc_bind_address = grpc_bind_address; + self + } + + /// Create a new instance for all-in-one mode, only allowing some arguments. 
+ pub fn new( + logging_config: LoggingConfig, + tracing_config: TracingConfig, + http_bind_address: SocketAddr, + grpc_bind_address: SocketAddr, + max_http_request_size: usize, + object_store_config: ObjectStoreConfig, + ) -> Self { + Self { + logging_config, + tracing_config, + http_bind_address, + grpc_bind_address, + max_http_request_size, + object_store_config, + } + } +} diff --git a/clap_blocks/src/single_tenant.rs b/clap_blocks/src/single_tenant.rs new file mode 100644 index 0000000..fb7fb95 --- /dev/null +++ b/clap_blocks/src/single_tenant.rs @@ -0,0 +1,11 @@ +//! CLI config for request authorization. + +/// Env var providing authz address +pub const CONFIG_AUTHZ_ENV_NAME: &str = "INFLUXDB_IOX_AUTHZ_ADDR"; +/// CLI flag for authz address +pub const CONFIG_AUTHZ_FLAG: &str = "authz-addr"; + +/// Env var for single tenancy deployments +pub const CONFIG_CST_ENV_NAME: &str = "INFLUXDB_IOX_SINGLE_TENANCY"; +/// CLI flag for single tenancy deployments +pub const CONFIG_CST_FLAG: &str = "single-tenancy"; diff --git a/clap_blocks/src/socket_addr.rs b/clap_blocks/src/socket_addr.rs new file mode 100644 index 0000000..02a1014 --- /dev/null +++ b/clap_blocks/src/socket_addr.rs @@ -0,0 +1,77 @@ +//! Config for socket addresses. +use std::{net::ToSocketAddrs, ops::Deref}; + +/// Parsable socket address. 
/// A socket address that can be parsed from CLI arguments / env vars.
///
/// Thin newtype over [`std::net::SocketAddr`] whose [`FromStr`] impl
/// resolves host names (via [`ToSocketAddrs`]) and keeps the first
/// resolved address.
///
/// [`FromStr`]: std::str::FromStr
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SocketAddr(std::net::SocketAddr);

impl Deref for SocketAddr {
    type Target = std::net::SocketAddr;

    // Expose the wrapped address's methods (port(), ip(), ...) directly.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl std::fmt::Display for SocketAddr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate to the inner address's Display formatting.
        write!(f, "{}", self.0)
    }
}

impl std::str::FromStr for SocketAddr {
    type Err = String;

    /// Resolve `s` and take the first address it yields.
    ///
    /// Resolution failures and empty resolution results are both reported
    /// as human-readable strings suitable for CLI error output.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut addrs = s
            .to_socket_addrs()
            .map_err(|e| format!("Cannot parse socket address '{s}': {e}"))?;
        addrs
            .next()
            .map(Self)
            .ok_or_else(|| format!("Found no addresses for '{s}'"))
    }
}

impl From<SocketAddr> for std::net::SocketAddr {
    fn from(addr: SocketAddr) -> Self {
        let SocketAddr(inner) = addr;
        inner
    }
}
+ match addr { + std::net::SocketAddr::V4(so) => { + assert_eq!(so, SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1234)) + } + std::net::SocketAddr::V6(so) => assert_eq!( + so, + SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 1234, 0, 0) + ), + }; + + assert_eq!( + SocketAddr::from_str("!@INv_a1d(ad0/resp_!").unwrap_err(), + "Cannot parse socket address '!@INv_a1d(ad0/resp_!': invalid socket address", + ); + } +} diff --git a/client_util/Cargo.toml b/client_util/Cargo.toml new file mode 100644 index 0000000..8b2e12f --- /dev/null +++ b/client_util/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "client_util" +description = "Shared code for IOx clients" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[lints] +workspace = true + +[dependencies] +http = "0.2.11" +reqwest = { version = "0.11", default-features = false, features = ["stream", "rustls-tls-native-roots"] } +thiserror = "1.0.56" +tonic = { workspace = true } +tower = "0.4" +workspace-hack = { version = "0.1", path = "../workspace-hack" } + +[dev-dependencies] +tokio = { version = "1.35", features = ["macros", "parking_lot", "rt-multi-thread"] } +mockito = { version = "1.2", default-features = false } diff --git a/client_util/src/connection.rs b/client_util/src/connection.rs new file mode 100644 index 0000000..d671502 --- /dev/null +++ b/client_util/src/connection.rs @@ -0,0 +1,295 @@ +use crate::tower::{SetRequestHeadersLayer, SetRequestHeadersService}; +use http::header::HeaderName; +use http::HeaderMap; +use http::{uri::InvalidUri, HeaderValue, Uri}; +use std::convert::TryInto; +use std::time::Duration; +use thiserror::Error; +use tonic::transport::{Channel, Endpoint}; +use tower::make::MakeConnection; + +/// The connection type used for clients. 
Use [`Builder`] to create +/// instances of [`Connection`] objects +#[derive(Debug, Clone)] +pub struct Connection { + grpc_connection: GrpcConnection, + http_connection: HttpConnection, +} + +impl Connection { + /// Create a new Connection + fn new(grpc_connection: GrpcConnection, http_connection: HttpConnection) -> Self { + Self { + grpc_connection, + http_connection, + } + } + + /// Consume `self` and return a [`GrpcConnection`] (suitable for use in + /// tonic clients) + pub fn into_grpc_connection(self) -> GrpcConnection { + self.grpc_connection + } + + /// Consume `self` and return a [`HttpConnection`] (suitable for making + /// calls to /api/v2 endpoints) + pub fn into_http_connection(self) -> HttpConnection { + self.http_connection + } +} + +/// The type used to make tonic (gRPC) requests +pub type GrpcConnection = SetRequestHeadersService; + +/// The type used to make raw http request +#[derive(Debug, Clone)] +pub struct HttpConnection { + /// The base uri of the IOx http API endpoint + uri: Uri, + /// http client connection + http_client: reqwest::Client, +} + +impl HttpConnection { + fn new(uri: Uri, http_client: reqwest::Client) -> Self { + Self { uri, http_client } + } + + /// Return a reference to the underyling http client + pub fn client(&self) -> &reqwest::Client { + &self.http_client + } + + /// Return a reference to the base uri of the IOx http API endpoint + pub fn uri(&self) -> &Uri { + &self.uri + } +} + +/// The default User-Agent header sent by the HTTP client. 
+pub const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); +/// The default connection timeout +pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(1); +/// The default request timeout +pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30); + +/// Errors returned by the ConnectionBuilder +#[derive(Debug, Error)] +pub enum Error { + /// Server returned an invalid argument error + #[error("Connection error: {}{}", source, details)] + TransportError { + /// underlying [`tonic::transport::Error`] + source: tonic::transport::Error, + /// stringified version of the tonic error's source + details: String, + }, + + /// Client received an unexpected error from the server + #[error("Invalid URI: {}", .0)] + InvalidUri(#[from] InvalidUri), +} + +// Custom impl to include underlying source (not included in tonic +// transport error) +impl From for Error { + fn from(source: tonic::transport::Error) -> Self { + use std::error::Error; + let details = source + .source() + .map(|e| format!(" ({e})")) + .unwrap_or_default(); + + Self::TransportError { source, details } + } +} + +/// Result type for the ConnectionBuilder +pub type Result = std::result::Result; + +/// A builder that produces a connection that can be used with any of the gRPC +/// clients +/// +/// ```no_run +/// #[tokio::main] +/// # async fn main() { +/// use client_util::connection::Builder; +/// use std::time::Duration; +/// +/// let connection = Builder::new() +/// .timeout(Duration::from_secs(42)) +/// .user_agent("my_awesome_client") +/// .build("http://127.0.0.1:8082/") +/// .await +/// .expect("connection must succeed"); +/// # } +/// ``` +#[derive(Debug, Clone)] +pub struct Builder { + user_agent: String, + headers: Vec<(HeaderName, HeaderValue)>, + connect_timeout: Duration, + timeout: Duration, +} + +impl std::default::Default for Builder { + fn default() -> Self { + Self { + user_agent: USER_AGENT.into(), + connect_timeout: DEFAULT_CONNECT_TIMEOUT, 
+ timeout: DEFAULT_TIMEOUT, + headers: Default::default(), + } + } +} + +impl Builder { + /// Create a new default builder + pub fn new() -> Self { + Default::default() + } + + /// Construct the [`Connection`] instance using the specified base URL. + pub async fn build(self, dst: D) -> Result + where + D: TryInto + Send, + { + let endpoint = self.create_endpoint(dst)?; + let channel = endpoint.connect().await?; + Ok(self.compose_middleware(channel, endpoint)) + } + + /// Construct the [`Connection`] instance using the specified base URL and custom connector. + pub async fn build_with_connector(self, dst: D, connector: C) -> Result + where + D: TryInto + Send, + C: MakeConnection + Send + 'static, + C::Connection: Unpin + Send + 'static, + C::Future: Send + 'static, + Box: From + Send + 'static, + { + let endpoint = self.create_endpoint(dst)?; + let channel = endpoint.connect_with_connector(connector).await?; + Ok(self.compose_middleware(channel, endpoint)) + } + + fn create_endpoint(&self, dst: D) -> Result + where + D: TryInto + Send, + { + let endpoint = Endpoint::from(dst.try_into()?) + .user_agent(&self.user_agent)? + .connect_timeout(self.connect_timeout) + .timeout(self.timeout); + Ok(endpoint) + } + + fn compose_middleware(self, channel: Channel, endpoint: Endpoint) -> Connection { + let headers_map: HeaderMap = self.headers.iter().cloned().collect(); + + // Compose channel with new tower middleware stack + let grpc_connection = tower::ServiceBuilder::new() + .layer(SetRequestHeadersLayer::new(self.headers)) + .service(channel); + + let http_client = reqwest::Client::builder() + .connection_verbose(true) + .default_headers(headers_map) + .build() + .expect("reqwest::Client should have built"); + + let http_connection = HttpConnection::new(endpoint.uri().clone(), http_client); + + Connection::new(grpc_connection, http_connection) + } + + /// Set the `User-Agent` header sent by this client. 
+ pub fn user_agent(self, user_agent: impl Into) -> Self { + Self { + user_agent: user_agent.into(), + ..self + } + } + + /// Sets a header to be included on all requests + pub fn header(self, header: impl Into, value: impl Into) -> Self { + let mut headers = self.headers; + headers.push((header.into(), value.into())); + Self { headers, ..self } + } + + /// Sets the maximum duration of time the client will wait for the IOx + /// server to accept the TCP connection before aborting the request. + /// + /// Note this does not bound the request duration - see + /// [`timeout`][Self::timeout]. + pub fn connect_timeout(self, timeout: Duration) -> Self { + Self { + connect_timeout: timeout, + ..self + } + } + + /// Bounds the total amount of time a single client HTTP request take before + /// being aborted. + /// + /// This timeout includes: + /// + /// - Establishing the TCP connection (see [`connect_timeout`]) + /// - Sending the HTTP request + /// - Waiting for, and receiving the entire HTTP response + /// + /// [`connect_timeout`]: Self::connect_timeout + pub fn timeout(self, timeout: Duration) -> Self { + Self { timeout, ..self } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use reqwest::Method; + + #[test] + fn test_builder_cloneable() { + // Clone is used by Conductor. 
+ fn assert_clone(_t: T) {} + assert_clone(Builder::default()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn headers_are_set() { + let mut mock_server = mockito::Server::new_async().await; + let url = mock_server.url(); + + let http_connection = Builder::new() + .header( + HeaderName::from_static("foo"), + HeaderValue::from_static("bar"), + ) + .build(&url) + .await + .unwrap() + .into_http_connection(); + + let url = format!("{url}/the_api"); + println!("Sending to {url}"); + + let m = mock_server + .mock("POST", "/the_api") + .with_status(201) + .with_body("world") + .match_header("FOO", "bar") + .create_async() + .await; + + http_connection + .client() + .request(Method::POST, &url) + .send() + .await + .expect("Error making http request"); + + m.assert_async().await; + } +} diff --git a/client_util/src/lib.rs b/client_util/src/lib.rs new file mode 100644 index 0000000..74a1a34 --- /dev/null +++ b/client_util/src/lib.rs @@ -0,0 +1,32 @@ +//! Shared InfluxDB IOx API client functionality +#![deny( + rustdoc::broken_intra_doc_links, + rustdoc::bare_urls, + rust_2018_idioms, + missing_debug_implementations, + unreachable_pub +)] +#![warn( + missing_docs, + clippy::todo, + clippy::dbg_macro, + clippy::clone_on_ref_ptr, + // See https://github.com/influxdata/influxdb_iox/pull/1671 + clippy::future_not_send, + clippy::todo, + clippy::dbg_macro, + unused_crate_dependencies +)] +#![allow(clippy::missing_docs_in_private_items)] + +// Workaround for "unused crate" lint false positives. +use workspace_hack as _; + +/// Builder for constructing connections for use with the various gRPC clients +pub mod connection; + +/// Helper to set client headers. +pub mod tower; + +/// Namespace <--> org/bucket utilities +pub mod namespace_translation; diff --git a/client_util/src/namespace_translation.rs b/client_util/src/namespace_translation.rs new file mode 100644 index 0000000..53f011e --- /dev/null +++ b/client_util/src/namespace_translation.rs @@ -0,0 +1,90 @@ +//! 
Contains logic to map namespace back/forth to org/bucket + +use thiserror::Error; + +/// Errors returned by namespace parsing +#[allow(missing_docs)] +#[derive(Debug, Error)] +pub enum Error { + #[error("Invalid namespace '{namespace}': {reason}")] + InvalidNamespace { namespace: String, reason: String }, +} + +impl Error { + fn new(namespace: impl Into, reason: impl Into) -> Self { + Self::InvalidNamespace { + namespace: namespace.into(), + reason: reason.into(), + } + } +} + +/// Splits up the namespace name into org_id and bucket_id +pub fn split_namespace(namespace: &str) -> Result<(&str, &str), Error> { + let mut iter = namespace.split('_'); + let org_id = iter.next().ok_or_else(|| Error::new(namespace, "empty"))?; + + if org_id.is_empty() { + return Err(Error::new(namespace, "No org_id found")); + } + + let bucket_id = iter + .next() + .ok_or_else(|| Error::new(namespace, "Could not find '_'"))?; + + if bucket_id.is_empty() { + return Err(Error::new(namespace, "No bucket_id found")); + } + + if iter.next().is_some() { + return Err(Error::new(namespace, "More than one '_'")); + } + + Ok((org_id, bucket_id)) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn split_good() { + assert_eq!(split_namespace("foo_bar").unwrap(), ("foo", "bar")); + } + + #[test] + #[should_panic(expected = "No org_id found")] + fn split_bad_empty() { + split_namespace("").unwrap(); + } + + #[test] + #[should_panic(expected = "No org_id found")] + fn split_bad_only_underscore() { + split_namespace("_").unwrap(); + } + + #[test] + #[should_panic(expected = "No org_id found")] + fn split_bad_empty_org_id() { + split_namespace("_ff").unwrap(); + } + + #[test] + #[should_panic(expected = "No bucket_id found")] + fn split_bad_empty_bucket_id() { + split_namespace("ff_").unwrap(); + } + + #[test] + #[should_panic(expected = "More than one '_'")] + fn split_too_many() { + split_namespace("ff_bf_").unwrap(); + } + + #[test] + #[should_panic(expected = "More than one '_'")] + fn 
split_way_too_many() { + split_namespace("ff_bf_dfd_3_f").unwrap(); + } +} diff --git a/client_util/src/tower.rs b/client_util/src/tower.rs new file mode 100644 index 0000000..73eae36 --- /dev/null +++ b/client_util/src/tower.rs @@ -0,0 +1,79 @@ +use http::header::HeaderName; +use http::{HeaderValue, Request, Response}; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tower::{Layer, Service}; + +/// `SetRequestHeadersLayer` sets the provided headers on all requests flowing through it +/// unless they're already set +#[derive(Debug, Clone)] +pub(crate) struct SetRequestHeadersLayer { + headers: Arc>, +} + +impl SetRequestHeadersLayer { + pub(crate) fn new(headers: Vec<(HeaderName, HeaderValue)>) -> Self { + Self { + headers: Arc::new(headers), + } + } +} + +impl Layer for SetRequestHeadersLayer { + type Service = SetRequestHeadersService; + + fn layer(&self, service: S) -> Self::Service { + SetRequestHeadersService { + service, + headers: Arc::clone(&self.headers), + } + } +} + +/// SetRequestHeadersService wraps an inner tower::Service and sets the provided +/// headers on requests flowing through it +#[derive(Debug, Clone)] +pub struct SetRequestHeadersService { + service: S, + headers: Arc>, +} + +impl SetRequestHeadersService { + /// Create sevice from inner service and headers. + pub fn new(service: S, headers: Vec<(HeaderName, HeaderValue)>) -> Self { + Self { + service, + headers: Arc::new(headers), + } + } + + /// De-construct service into parts. + /// + /// The can be used to call [`new`](Self::new) again. 
+ pub fn into_parts(self) -> (S, Arc>) { + let SetRequestHeadersService { service, headers } = self; + + (service, headers) + } +} + +impl Service> for SetRequestHeadersService +where + S: Service, Response = Response>, +{ + type Response = Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, mut request: Request) -> Self::Future { + let headers = request.headers_mut(); + for (name, value) in self.headers.iter() { + headers.insert(name, value.clone()); + } + self.service.call(request) + } +} diff --git a/data_types/Cargo.toml b/data_types/Cargo.toml new file mode 100644 index 0000000..c38745c --- /dev/null +++ b/data_types/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "data_types" +description = "Shared data types" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[lints] +workspace = true + +[dependencies] +arrow-buffer = { workspace = true } +bytes = "1.5" +chrono = { version = "0.4", default-features = false } +croaring = "1.0.0" +influxdb-line-protocol = { path = "../influxdb_line_protocol" } +iox_time = { path = "../iox_time" } +generated_types = { path = "../generated_types" } +murmur3 = "0.5.2" +observability_deps = { path = "../observability_deps" } +once_cell = "1" +ordered-float = "4" +percent-encoding = "2.3.1" +prost = { workspace = true } +schema = { path = "../schema" } +serde_json = "1.0" +siphasher = "1.0" +sha2 = { version = "0.10", default-features = false } +snafu = "0.8" +sqlx = { version = "0.7.3", features = ["runtime-tokio-rustls", "postgres", "uuid"] } +thiserror = "1.0.56" +uuid = { version = "1", features = ["v4"] } +workspace-hack = { version = "0.1", path = "../workspace-hack" } + +[dev-dependencies] # In alphabetical order +assert_matches = "1" +paste = "1.0.14" +proptest = { version = "1.4.0", default-features = false } +test_helpers = { path = 
"../test_helpers" } +hex = "0.4.2" diff --git a/data_types/src/columns.rs b/data_types/src/columns.rs new file mode 100644 index 0000000..1c6b0a9 --- /dev/null +++ b/data_types/src/columns.rs @@ -0,0 +1,997 @@ +//! Types having to do with columns. + +use super::TableId; +use generated_types::influxdata::iox::{column_type::v1 as proto, gossip}; +use influxdb_line_protocol::FieldValue; +use schema::{builder::SchemaBuilder, sort::SortKey, InfluxColumnType, InfluxFieldType, Schema}; +use snafu::Snafu; +use std::cmp::Ordering; +use std::collections::HashSet; +use std::{ + collections::{BTreeMap, BTreeSet, HashMap}, + convert::TryFrom, + ops::Deref, + sync::Arc, +}; + +/// Unique ID for a `Column` +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type)] +#[sqlx(transparent)] +pub struct ColumnId(i64); + +#[allow(missing_docs)] +impl ColumnId { + pub fn new(v: i64) -> Self { + Self(v) + } + pub fn get(&self) -> i64 { + self.0 + } +} + +/// Column definitions for a table indexed by their name +#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)] +pub struct ColumnsByName(BTreeMap, ColumnSchema>); + +impl From, ColumnSchema>> for ColumnsByName { + fn from(value: BTreeMap, ColumnSchema>) -> Self { + Self(value) + } +} + +impl ColumnsByName { + /// Create a new instance holding the given [`Column`]s. + pub fn new(columns: impl IntoIterator) -> Self { + Self( + columns + .into_iter() + .map(|c| { + ( + Arc::from(c.name), + ColumnSchema { + id: c.id, + column_type: c.column_type, + }, + ) + }) + .collect(), + ) + } + + /// Add the given column name and schema to this set of columns. + /// + /// # Panics + /// + /// This method panics if a column of the same name already exists in `self`. + pub fn add_column(&mut self, column_name: impl Into>, column_schema: ColumnSchema) { + let old = self.0.insert(column_name.into(), column_schema); + assert!(old.is_none()); + } + + /// Iterate over the names and columns. 
+ pub fn iter(&self) -> impl Iterator, &ColumnSchema)> { + self.0.iter() + } + + /// Whether a column with this name is in the set. + pub fn contains_column_name(&self, name: &str) -> bool { + self.0.contains_key(name) + } + + /// Return number of columns in the set. + pub fn column_count(&self) -> usize { + self.0.len() + } + + /// Return the set of column names. Used in combination with a write operation's + /// column names to determine whether a write would exceed the max allowed columns. + pub fn names(&self) -> BTreeSet<&str> { + self.0.keys().map(|name| name.as_ref()).collect() + } + + /// Return an iterator of the set of column IDs. + pub fn ids(&self) -> impl Iterator + '_ { + self.0.values().map(|c| c.id) + } + + /// Return column ids of the given column names + /// + /// # Panics + /// + /// Panics if any of the names are not found in this set. + pub fn ids_for_names(&self, names: impl IntoIterator + Send) -> SortKeyIds + where + T: AsRef, + { + SortKeyIds::from(names.into_iter().map(|name| { + let name = name.as_ref(); + self.get(name) + .unwrap_or_else(|| panic!("column name not found: {}", name)) + .id + .get() + })) + } + + /// Get a column by its name. + pub fn get(&self, name: &str) -> Option<&ColumnSchema> { + self.0.get(name) + } + + /// Get the `ColumnId` for the time column, if present (a table created through + /// `table_load_or_create` will always have a time column). + pub fn time_column_id(&self) -> Option { + self.get(schema::TIME_COLUMN_NAME).map(|column| column.id) + } + + /// Create `ID->name` map for columns. 
+ pub fn id_map(&self) -> HashMap> { + self.0 + .iter() + .map(|(name, c)| (c.id, Arc::clone(name))) + .collect() + } +} + +impl IntoIterator for ColumnsByName { + type Item = (Arc, ColumnSchema); + type IntoIter = std::collections::btree_map::IntoIter, ColumnSchema>; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl FromIterator<(Arc, ColumnSchema)> for ColumnsByName { + fn from_iter, ColumnSchema)>>(iter: T) -> Self { + Self(BTreeMap::from_iter(iter)) + } +} + +// ColumnsByName is a newtype so that we can implement this `TryFrom` in this crate +impl TryFrom for Schema { + type Error = schema::builder::Error; + + fn try_from(value: ColumnsByName) -> Result { + let mut builder = SchemaBuilder::new(); + + for (column_name, column_schema) in value.into_iter() { + let t = InfluxColumnType::from(column_schema.column_type); + builder.influx_column(column_name.as_ref(), t); + } + + builder.build() + } +} + +/// Data object for a column +#[derive(Debug, Clone, sqlx::FromRow, Eq, PartialEq)] +pub struct Column { + /// the column id + pub id: ColumnId, + /// the table id the column is in + pub table_id: TableId, + /// the name of the column, which is unique in the table + pub name: String, + /// the logical type of the column + pub column_type: ColumnType, +} + +impl Column { + /// returns true if the column type is a tag + pub fn is_tag(&self) -> bool { + self.column_type == ColumnType::Tag + } + + /// returns true if the column type matches the line protocol field value type + pub fn matches_field_type(&self, field_value: &FieldValue<'_>) -> bool { + match field_value { + FieldValue::I64(_) => self.column_type == ColumnType::I64, + FieldValue::U64(_) => self.column_type == ColumnType::U64, + FieldValue::F64(_) => self.column_type == ColumnType::F64, + FieldValue::String(_) => self.column_type == ColumnType::String, + FieldValue::Boolean(_) => self.column_type == ColumnType::Bool, + } + } +} + +/// The column id and its type for a column 
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +pub struct ColumnSchema { + /// the column id + pub id: ColumnId, + /// the column type + pub column_type: ColumnType, +} + +impl ColumnSchema { + /// returns true if the column is a tag + pub fn is_tag(&self) -> bool { + self.column_type == ColumnType::Tag + } + + /// returns true if the column matches the line protocol field value type + pub fn matches_field_type(&self, field_value: &FieldValue<'_>) -> bool { + matches!( + (field_value, self.column_type), + (FieldValue::I64(_), ColumnType::I64) + | (FieldValue::U64(_), ColumnType::U64) + | (FieldValue::F64(_), ColumnType::F64) + | (FieldValue::String(_), ColumnType::String) + | (FieldValue::Boolean(_), ColumnType::Bool) + ) + } + + /// Returns true if `mb_column` is of the same type as `self`. + pub fn matches_type(&self, mb_column_influx_type: InfluxColumnType) -> bool { + self.column_type == mb_column_influx_type + } +} + +impl TryFrom<&gossip::v1::Column> for ColumnSchema { + type Error = Box; + + fn try_from(v: &gossip::v1::Column) -> Result { + Ok(Self { + id: ColumnId::new(v.column_id), + column_type: ColumnType::try_from(v.column_type as i16)?, + }) + } +} + +/// The column data type +#[allow(missing_docs)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash, sqlx::Type)] +#[repr(i16)] +pub enum ColumnType { + I64 = 1, + U64 = 2, + F64 = 3, + Bool = 4, + String = 5, + Time = 6, + Tag = 7, +} + +impl ColumnType { + /// the short string description of the type + pub fn as_str(&self) -> &'static str { + match self { + Self::I64 => "i64", + Self::U64 => "u64", + Self::F64 => "f64", + Self::Bool => "bool", + Self::String => "string", + Self::Time => "time", + Self::Tag => "tag", + } + } +} + +impl std::fmt::Display for ColumnType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = self.as_str(); + + write!(f, "{s}") + } +} + +/// Errors deserialising a protobuf serialised [`ColumnType`]. 
+#[derive(Debug, Snafu)] +#[snafu(display("invalid column value"))] +#[allow(missing_copy_implementations)] +pub struct ColumnTypeProtoError {} + +impl TryFrom for ColumnType { + type Error = ColumnTypeProtoError; + + fn try_from(value: i16) -> Result { + match value { + x if x == Self::I64 as i16 => Ok(Self::I64), + x if x == Self::U64 as i16 => Ok(Self::U64), + x if x == Self::F64 as i16 => Ok(Self::F64), + x if x == Self::Bool as i16 => Ok(Self::Bool), + x if x == Self::String as i16 => Ok(Self::String), + x if x == Self::Time as i16 => Ok(Self::Time), + x if x == Self::Tag as i16 => Ok(Self::Tag), + _ => Err(ColumnTypeProtoError {}), + } + } +} + +impl From for ColumnType { + fn from(value: InfluxColumnType) -> Self { + match value { + InfluxColumnType::Tag => Self::Tag, + InfluxColumnType::Field(InfluxFieldType::Float) => Self::F64, + InfluxColumnType::Field(InfluxFieldType::Integer) => Self::I64, + InfluxColumnType::Field(InfluxFieldType::UInteger) => Self::U64, + InfluxColumnType::Field(InfluxFieldType::String) => Self::String, + InfluxColumnType::Field(InfluxFieldType::Boolean) => Self::Bool, + InfluxColumnType::Timestamp => Self::Time, + } + } +} + +impl From for InfluxColumnType { + fn from(value: ColumnType) -> Self { + match value { + ColumnType::I64 => Self::Field(InfluxFieldType::Integer), + ColumnType::U64 => Self::Field(InfluxFieldType::UInteger), + ColumnType::F64 => Self::Field(InfluxFieldType::Float), + ColumnType::Bool => Self::Field(InfluxFieldType::Boolean), + ColumnType::String => Self::Field(InfluxFieldType::String), + ColumnType::Time => Self::Timestamp, + ColumnType::Tag => Self::Tag, + } + } +} + +impl PartialEq for ColumnType { + fn eq(&self, got: &InfluxColumnType) -> bool { + match self { + Self::I64 => matches!(got, InfluxColumnType::Field(InfluxFieldType::Integer)), + Self::U64 => matches!(got, InfluxColumnType::Field(InfluxFieldType::UInteger)), + Self::F64 => matches!(got, InfluxColumnType::Field(InfluxFieldType::Float)), + 
Self::Bool => matches!(got, InfluxColumnType::Field(InfluxFieldType::Boolean)), + Self::String => matches!(got, InfluxColumnType::Field(InfluxFieldType::String)), + Self::Time => matches!(got, InfluxColumnType::Timestamp), + Self::Tag => matches!(got, InfluxColumnType::Tag), + } + } +} + +/// Returns the `ColumnType` for the passed in line protocol `FieldValue` type +pub fn column_type_from_field(field_value: &FieldValue<'_>) -> ColumnType { + match field_value { + FieldValue::I64(_) => ColumnType::I64, + FieldValue::U64(_) => ColumnType::U64, + FieldValue::F64(_) => ColumnType::F64, + FieldValue::String(_) => ColumnType::String, + FieldValue::Boolean(_) => ColumnType::Bool, + } +} + +impl TryFrom for ColumnType { + type Error = &'static str; + + fn try_from(value: proto::ColumnType) -> Result { + Ok(match value { + proto::ColumnType::I64 => Self::I64, + proto::ColumnType::U64 => Self::U64, + proto::ColumnType::F64 => Self::F64, + proto::ColumnType::Bool => Self::Bool, + proto::ColumnType::String => Self::String, + proto::ColumnType::Time => Self::Time, + proto::ColumnType::Tag => Self::Tag, + proto::ColumnType::Unspecified => return Err("unknown column type"), + }) + } +} + +impl From for proto::ColumnType { + fn from(value: ColumnType) -> Self { + match value { + ColumnType::I64 => Self::I64, + ColumnType::U64 => Self::U64, + ColumnType::F64 => Self::F64, + ColumnType::Bool => Self::Bool, + ColumnType::String => Self::String, + ColumnType::Time => Self::Time, + ColumnType::Tag => Self::Tag, + } + } +} + +/// Set of columns and used as Set data type. +/// +/// # Data Structure +/// This is internally implemented as a sorted vector. The sorting allows for fast [`PartialEq`]/[`Eq`]/[`Hash`] and +/// ensures that the PostgreSQL data is deterministic. Note that PostgreSQL does NOT have a set type at the moment, so +/// this is stored as an array. 
+#[derive(Debug, Clone, PartialEq, Eq, Hash, sqlx::Type)] +#[sqlx(transparent, no_pg_array)] +pub struct ColumnSet(Vec); + +impl ColumnSet { + /// Create new column set. + /// + /// The order of the passed columns will NOT be preserved. + /// + /// # Panic + /// Panics when the set of passed columns contains duplicates. + pub fn new(columns: I) -> Self + where + I: IntoIterator, + { + let mut columns: Vec = columns.into_iter().collect(); + columns.sort(); + + assert!( + columns.windows(2).all(|w| w[0] != w[1]), + "set contains duplicates" + ); + + columns.shrink_to_fit(); + + Self(columns) + } + + /// Create a new empty [`ColumnSet`] + pub fn empty() -> Self { + Self(Vec::new()) + } + + /// Estimate the memory consumption of this object and its contents + pub fn size(&self) -> usize { + std::mem::size_of_val(self) + (std::mem::size_of::() * self.0.capacity()) + } + + /// The set is empty or not + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Computes the union of `self` and `other` + pub fn union(&mut self, other: &Self) { + let mut insert_idx = 0; + let mut src_idx = 0; + + while insert_idx < self.0.len() && src_idx < other.0.len() { + let s = self.0[insert_idx]; + let o = other.0[src_idx]; + + match s.cmp(&o) { + Ordering::Less => insert_idx += 1, + Ordering::Equal => { + insert_idx += 1; + src_idx += 1; + } + Ordering::Greater => { + self.0.insert(insert_idx, o); + insert_idx += 1; + src_idx += 1; + } + } + } + self.0.extend_from_slice(&other.0[src_idx..]); + } + + /// Returns the indices and ids in `self` that are present in both `self` and `other` + /// + /// ``` + /// # use data_types::{ColumnId, ColumnSet}; + /// let a = ColumnSet::new([1, 2, 4, 6, 7].into_iter().map(ColumnId::new)); + /// let b = ColumnSet::new([2, 4, 6].into_iter().map(ColumnId::new)); + /// + /// assert_eq!( + /// a.intersect(&b).collect::>(), + /// vec![(1, b[0]), (2, b[1]), (3, b[2])] + /// ) + /// ``` + pub fn intersect<'a>( + &'a self, + other: &'a Self, + ) -> impl 
Iterator + 'a { + let mut left_idx = 0; + let mut right_idx = 0; + std::iter::from_fn(move || loop { + let s = self.0.get(left_idx)?; + let o = other.get(right_idx)?; + + match s.cmp(o) { + Ordering::Less => left_idx += 1, + Ordering::Greater => right_idx += 1, + Ordering::Equal => { + let t = left_idx; + left_idx += 1; + right_idx += 1; + return Some((t, *s)); + } + } + }) + } +} + +impl From for Vec { + fn from(set: ColumnSet) -> Self { + set.0 + } +} + +impl Deref for ColumnSet { + type Target = [ColumnId]; + + fn deref(&self) -> &Self::Target { + self.0.deref() + } +} + +/// Set of sorted column IDs in a specific given order at creation time, to be used as a +/// [`SortKey`] by looking up the column names in the table's schema. +#[derive(Debug, Clone, PartialEq, Eq, Hash, sqlx::Type, Default)] +#[sqlx(transparent, no_pg_array)] +pub struct SortKeyIds(Vec); + +impl SortKeyIds { + /// Create new sorted column set. + /// + /// The order of the passed columns will be preserved. + /// + /// # Panic + /// Panics when the set of passed columns contains duplicates. + pub fn new(columns: I) -> Self + where + I: IntoIterator, + { + let mut columns: Vec = columns.into_iter().collect(); + + // Validate the ID set contains no duplicates. + // + // This validates an invariant in debug builds, skipping the cost + // for release builds. + if cfg!(debug_assertions) { + SortKeyIds::check_for_deplicates(&columns); + } + + // Must continue with columns in original order + columns.shrink_to_fit(); + + Self(columns) + } + + /// Given another set of sort key IDs, merge them together and, if needed, return a value to + /// use to update the catalog. + /// + /// If `other` contains any column IDs that are not present in `self`, create a new + /// `SortKeyIds` instance that includes the new columns in `other` (in the same order they + /// appear in `other`) appended to the existing columns, but keeping the time column ID last. 
+ /// + /// If existing columns appear in `self` in a different order than they appear in `other`, the + /// order in `self` takes precedence and remains unchanged. + /// + /// If `self` contains all the sort keys in `other` already (regardless of order), this will + /// return `None` as no update to the catalog is needed. + pub fn maybe_append(&self, other: &Self, time_column_id: ColumnId) -> Option { + let existing_columns_without_time = self + .iter() + .cloned() + .filter(|&column_id| column_id != time_column_id); + + let mut new_columns = other + .iter() + .cloned() + .filter(|column_id| !self.contains(column_id)) + .peekable(); + + if new_columns.peek().is_none() { + None + } else { + Some(SortKeyIds::new( + existing_columns_without_time + .chain(new_columns) + .chain(std::iter::once(time_column_id)), + )) + } + } + + /// Estimate the memory consumption of this object and its contents + pub fn size(&self) -> usize { + std::mem::size_of_val(self) + (std::mem::size_of::() * self.0.capacity()) + } + + /// Build a [`SortKey`] from [`SortKeyIds`]; looking up column names in the provided + /// [`ColumnsByName`] map by converting it to a `HashMap. If you already have + /// an id-to-name column map, use [`SortKeyIds::to_sort_key_using_map`] instead. + /// + /// If you have a [`Partition`][super::Partition], it may be more convenient to call the + /// [`Partition::sort_key`][super::Partition::sort_key] method instead! + /// + /// # Panics + /// + /// Will panic if an ID isn't found in the column map. + pub fn to_sort_key(&self, columns: &ColumnsByName) -> SortKey { + let column_id_map = columns.id_map(); + self.to_sort_key_using_map(&column_id_map) + } + + /// Build a [`SortKey`] from [`SortKeyIds`]; looking up column names in the provided + /// [`HashMap`] map. + /// + /// If you have a [`Partition`][super::Partition], it may be more convenient to call the + /// [`Partition::sort_key`][super::Partition::sort_key] method instead! 
+ /// + /// # Panics + /// + /// Will panic if an ID isn't found in the column map. + pub fn to_sort_key_using_map(&self, column_id_map: &HashMap>) -> SortKey { + SortKey::from_columns(self.0.iter().map(|id| { + Arc::clone( + column_id_map.get(id).unwrap_or_else(|| { + panic!("cannot find column names for sort key id {}", id.get()) + }), + ) + })) + } + + /// Returns `true` if `other` is a monotonic update of `self`. + /// + /// # Panics + /// + /// Assumes "time" is the last column in both sets, and panics if the last + /// columns are not identical. + pub fn is_monotonic_update(&self, other: &Self) -> bool { + // The SortKeyIds always reference the time column last (if set). + if self.0.last().is_some() { + assert_eq!( + self.0.last(), + other.last(), + "last column in sort IDs must be time, and cannot change" + ); + } + + // Ensure the values in other are a prefix match, with the exception of + // the last "time" column. + self.0.len() <= other.len() + && self + .0 + .iter() + .take(self.0.len().saturating_sub(1)) + .zip(other.iter()) + .all(|(a, b)| a == b) + } + + fn check_for_deplicates(columns: &[ColumnId]) { + let mut column_ids: HashSet = HashSet::with_capacity(columns.len()); + for c in columns { + match column_ids.get(&c.0) { + Some(_) => { + panic!("set contains duplicates"); + } + _ => { + column_ids.insert(c.0); + } + } + } + } +} + +impl From for Vec { + fn from(set: SortKeyIds) -> Self { + set.0 + } +} + +impl Deref for SortKeyIds { + type Target = [ColumnId]; + + fn deref(&self) -> &Self::Target { + self.0.deref() + } +} + +impl From for SortKeyIds +where + I: IntoIterator, +{ + fn from(ids: I) -> Self { + Self::new(ids.into_iter().map(ColumnId::new).collect::>()) + } +} + +impl From<&SortKeyIds> for Vec { + fn from(val: &SortKeyIds) -> Self { + val.0.iter().map(|id| id.get()).collect() + } +} + +impl From<&SortKeyIds> for generated_types::influxdata::iox::catalog::v1::SortKeyIds { + fn from(val: &SortKeyIds) -> Self { + 
generated_types::influxdata::iox::catalog::v1::SortKeyIds { + array_sort_key_ids: val.into(), + } + } +} + +#[cfg(test)] +mod tests { + use assert_matches::assert_matches; + + use super::*; + + #[test] + #[should_panic = "set contains duplicates"] + fn test_column_set_duplicates() { + ColumnSet::new([ColumnId::new(1), ColumnId::new(2), ColumnId::new(1)]); + } + + #[test] + fn test_column_set_eq() { + let set_1 = ColumnSet::new([ColumnId::new(1), ColumnId::new(2)]); + let set_2 = ColumnSet::new([ColumnId::new(2), ColumnId::new(1)]); + assert_eq!(set_1, set_2); + } + + #[test] + fn test_column_set_union_intersect() { + let a = ColumnSet::new([1, 2, 5, 7].into_iter().map(ColumnId::new)); + let b = ColumnSet::new([1, 5, 6, 7, 8].into_iter().map(ColumnId::new)); + + let mut t = ColumnSet::empty(); + t.union(&a); + assert_eq!(t, a); + + assert_eq!( + t.intersect(&a).collect::>(), + vec![(0, a[0]), (1, a[1]), (2, a[2]), (3, a[3])] + ); + + t.union(&b); + let expected = ColumnSet::new([1, 2, 5, 6, 7, 8].into_iter().map(ColumnId::new)); + assert_eq!(t, expected); + + assert_eq!( + t.intersect(&a).collect::>(), + vec![(0, a[0]), (1, a[1]), (2, a[2]), (4, a[3])] + ); + + assert_eq!( + t.intersect(&b).collect::>(), + vec![(0, b[0]), (2, b[1]), (3, b[2]), (4, b[3]), (5, b[4])] + ); + } + + #[test] + #[should_panic = "set contains duplicates"] + fn test_sorted_column_set_duplicates() { + SortKeyIds::new([ + ColumnId::new(2), + ColumnId::new(1), + ColumnId::new(3), + ColumnId::new(1), + ]); + } + + #[test] + fn test_sorted_column_set() { + let set = SortKeyIds::new([ColumnId::new(2), ColumnId::new(1), ColumnId::new(3)]); + // verify the order is preserved + assert_eq!(set[0], ColumnId::new(2)); + assert_eq!(set[1], ColumnId::new(1)); + assert_eq!(set[2], ColumnId::new(3)); + } + + #[test] + fn test_column_schema() { + assert_eq!( + ColumnType::try_from(proto::ColumnType::I64).unwrap(), + ColumnType::I64, + ); + assert_eq!( + ColumnType::try_from(proto::ColumnType::U64).unwrap(), 
+ ColumnType::U64, + ); + assert_eq!( + ColumnType::try_from(proto::ColumnType::F64).unwrap(), + ColumnType::F64, + ); + assert_eq!( + ColumnType::try_from(proto::ColumnType::Bool).unwrap(), + ColumnType::Bool, + ); + assert_eq!( + ColumnType::try_from(proto::ColumnType::String).unwrap(), + ColumnType::String, + ); + assert_eq!( + ColumnType::try_from(proto::ColumnType::Time).unwrap(), + ColumnType::Time, + ); + assert_eq!( + ColumnType::try_from(proto::ColumnType::Tag).unwrap(), + ColumnType::Tag, + ); + + assert!(ColumnType::try_from(proto::ColumnType::Unspecified).is_err()); + } + + #[test] + fn test_gossip_proto_conversion() { + let proto = gossip::v1::Column { + name: "bananas".to_string(), + column_id: 42, + column_type: gossip::v1::column::ColumnType::String as _, + }; + + let got = ColumnSchema::try_from(&proto).expect("should succeed"); + assert_matches!(got, ColumnSchema{id, column_type} => { + assert_eq!(id.get(), 42); + assert_eq!(column_type, ColumnType::String); + }); + } + + #[test] + fn test_gossip_proto_conversion_invalid_type() { + let proto = gossip::v1::Column { + name: "bananas".to_string(), + column_id: 42, + column_type: 42, + }; + + ColumnSchema::try_from(&proto).expect_err("should succeed"); + } + + #[test] + fn test_columns_by_names_exist() { + let columns = build_columns_by_names(); + + let ids = columns.ids_for_names(["foo", "bar"]); + assert_eq!(ids, SortKeyIds::from([1, 2])); + } + + #[test] + fn test_columns_by_names_exist_different_order() { + let columns = build_columns_by_names(); + + let ids = columns.ids_for_names(["bar", "foo"]); + assert_eq!(ids, SortKeyIds::from([2, 1])); + } + + #[test] + #[should_panic = "column name not found: baz"] + fn test_columns_by_names_not_exist() { + let columns = build_columns_by_names(); + columns.ids_for_names(["foo", "baz"]); + } + + fn build_columns_by_names() -> ColumnsByName { + let mut columns: BTreeMap, ColumnSchema> = BTreeMap::new(); + columns.insert( + "foo".into(), + ColumnSchema { + 
id: ColumnId::new(1), + column_type: ColumnType::I64, + }, + ); + columns.insert( + "bar".into(), + ColumnSchema { + id: ColumnId::new(2), + column_type: ColumnType::I64, + }, + ); + columns.insert( + "time".into(), + ColumnSchema { + id: ColumnId::new(3), + column_type: ColumnType::Time, + }, + ); + columns.insert( + "tag1".into(), + ColumnSchema { + id: ColumnId::new(4), + column_type: ColumnType::Tag, + }, + ); + + ColumnsByName(columns) + } + + // panic if the sort_key_ids are not found in the columns + #[test] + #[should_panic(expected = "cannot find column names for sort key id 3")] + fn test_panic_build_sort_key_from_ids_and_map() { + // table columns + let uno = ColumnSchema { + id: ColumnId::new(1), + column_type: ColumnType::Tag, + }; + let dos = ColumnSchema { + id: ColumnId::new(2), + column_type: ColumnType::Tag, + }; + let mut column_map = ColumnsByName::default(); + column_map.add_column("uno", uno); + column_map.add_column("dos", dos); + + // sort_key_ids include some columns that are not in the columns + let sort_key_ids = SortKeyIds::from([2, 3]); + sort_key_ids.to_sort_key(&column_map); + } + + #[test] + fn test_build_sort_key_from_ids_and_map() { + // table columns + let uno = ColumnSchema { + id: ColumnId::new(1), + column_type: ColumnType::Tag, + }; + let dos = ColumnSchema { + id: ColumnId::new(2), + column_type: ColumnType::Tag, + }; + let tres = ColumnSchema { + id: ColumnId::new(3), + column_type: ColumnType::Tag, + }; + let mut column_map = ColumnsByName::default(); + column_map.add_column("uno", uno); + column_map.add_column("dos", dos); + column_map.add_column("tres", tres); + + // sort_key_ids is empty + let sort_key_ids = SortKeyIds::default(); + let sort_key = sort_key_ids.to_sort_key(&column_map); + assert_eq!(sort_key, SortKey::empty()); + + // sort_key_ids include all columns and in the same order + let sort_key_ids = SortKeyIds::from([1, 2, 3]); + let sort_key = sort_key_ids.to_sort_key(&column_map); + assert_eq!(sort_key, 
SortKey::from_columns(vec!["uno", "dos", "tres"])); + + // sort_key_ids include all columns but in different order + let sort_key_ids = SortKeyIds::from([2, 3, 1]); + let sort_key = sort_key_ids.to_sort_key(&column_map); + assert_eq!(sort_key, SortKey::from_columns(vec!["dos", "tres", "uno"])); + + // sort_key_ids include some columns + let sort_key_ids = SortKeyIds::from([2, 3]); + let sort_key = sort_key_ids.to_sort_key(&column_map); + assert_eq!(sort_key, SortKey::from_columns(vec!["dos", "tres"])); + + // sort_key_ids include some columns in different order + let sort_key_ids = SortKeyIds::from([3, 1]); + let sort_key = sort_key_ids.to_sort_key(&column_map); + assert_eq!(sort_key, SortKey::from_columns(vec!["tres", "uno"])); + } + + #[test] + fn test_sort_key_ids_round_trip_encoding() { + let original = SortKeyIds::from([1, 2, 3]); + + let encoded: generated_types::influxdata::iox::catalog::v1::SortKeyIds = (&original).into(); + + let decoded: SortKeyIds = encoded.array_sort_key_ids.into(); + assert_eq!(decoded, original); + } + + macro_rules! test_is_monotonic_update { + ( + $name:ident, + a = $a:expr, + b = $b:expr, + want = $want:expr + ) => { + paste::paste! 
{ + #[test] + fn []() { + let a = SortKeyIds::new($a.into_iter().map(ColumnId::new)); + let b = SortKeyIds::new($b.into_iter().map(ColumnId::new)); + assert_eq!(a.is_monotonic_update(&b), $want) + } + } + }; + } + + test_is_monotonic_update!(equal, a = [42, 24, 1], b = [42, 24, 1], want = true); + + test_is_monotonic_update!(empty, a = [], b = [42, 24, 1], want = true); + + test_is_monotonic_update!( + longer_with_time, + a = [42, 24, 1], + b = [42, 24, 13, 1], + want = true + ); + + test_is_monotonic_update!(shorter_with_time, a = [42, 24, 1], b = [1], want = false); + + test_is_monotonic_update!( + mismatch_with_time, + a = [42, 24, 1], + b = [24, 42, 1], + want = false + ); + + test_is_monotonic_update!(mismatch, a = [42, 24, 1], b = [24, 42, 1], want = false); +} diff --git a/data_types/src/lib.rs b/data_types/src/lib.rs new file mode 100644 index 0000000..951af51 --- /dev/null +++ b/data_types/src/lib.rs @@ -0,0 +1,2799 @@ +//! Shared data types + +// `clippy::use_self` is deliberately excluded from the lints this crate uses. +// See . +#![warn( + missing_copy_implementations, + missing_debug_implementations, + missing_docs, + clippy::explicit_iter_loop, + // See https://github.com/influxdata/influxdb_iox/pull/1671 + clippy::future_not_send, + clippy::clone_on_ref_ptr, + clippy::todo, + clippy::dbg_macro, + unused_crate_dependencies +)] + +use thiserror::Error; +// Workaround for "unused crate" lint false positives. 
+use workspace_hack as _; + +mod columns; +pub use columns::*; +mod namespace_name; +pub use namespace_name::*; +pub mod partition_template; +use partition_template::*; +pub mod partition; +pub use partition::*; +pub mod sequence_number_set; +pub mod service_limits; +pub mod snapshot; + +pub use service_limits::*; + +use observability_deps::tracing::warn; +use schema::TIME_COLUMN_NAME; +use snafu::Snafu; +use std::{ + borrow::Borrow, + collections::{BTreeMap, BTreeSet, HashMap}, + convert::TryFrom, + fmt::{Display, Write}, + mem::{self, size_of_val}, + num::{FpCategory, NonZeroU64}, + ops::{Add, Deref, Sub}, + sync::Arc, +}; +use uuid::Uuid; + +/// Errors deserialising a protobuf serialised [`ParquetFile`]. +#[derive(Debug, Snafu)] +#[snafu(display("invalid compaction level value"))] +#[allow(missing_copy_implementations)] +pub struct CompactionLevelProtoError {} + +/// Compaction levels +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, sqlx::Type)] +#[repr(i16)] +pub enum CompactionLevel { + /// The starting compaction level for parquet files persisted by an Ingester is zero. + Initial = 0, + /// Level of files persisted by a Compactor that do not overlap with non-level-0 files. 
+ FileNonOverlapped = 1, + /// Level of files persisted by a Compactor that are fully compacted and should not be + /// recompacted unless a new overlapping Initial level file arrives + Final = 2, +} + +impl Display for CompactionLevel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Initial => write!(f, "CompactionLevel::L0"), + Self::FileNonOverlapped => write!(f, "CompactionLevel::L1"), + Self::Final => write!(f, "CompactionLevel::L2"), + } + } +} + +impl TryFrom for CompactionLevel { + type Error = CompactionLevelProtoError; + + fn try_from(value: i32) -> Result { + match value { + x if x == Self::Initial as i32 => Ok(Self::Initial), + x if x == Self::FileNonOverlapped as i32 => Ok(Self::FileNonOverlapped), + x if x == Self::Final as i32 => Ok(Self::Final), + _ => Err(CompactionLevelProtoError {}), + } + } +} + +impl CompactionLevel { + /// When compacting files of this level, provide the level that the resulting file should be. + /// Does not exceed the maximum available level. 
+ pub fn next(&self) -> Self { + match self { + Self::Initial => Self::FileNonOverlapped, + Self::FileNonOverlapped => Self::Final, + Self::Final => Self::Final, + } + } + + /// Return previous level + pub fn prev(&self) -> Self { + match self { + Self::Initial => Self::Initial, + Self::FileNonOverlapped => Self::Initial, + Self::Final => Self::FileNonOverlapped, + } + } + + /// Returns all levels + pub fn all() -> &'static [Self] { + &[Self::Initial, Self::FileNonOverlapped, Self::Final] + } + + /// Static name + pub fn name(&self) -> &'static str { + match self { + Self::Initial => "L0", + Self::FileNonOverlapped => "L1", + Self::Final => "L2", + } + } +} + +/// Unique ID for a `Namespace` +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type)] +#[sqlx(transparent)] +pub struct NamespaceId(i64); + +#[allow(missing_docs)] +impl NamespaceId { + pub const fn new(v: i64) -> Self { + Self(v) + } + pub fn get(&self) -> i64 { + self.0 + } +} + +impl std::fmt::Display for NamespaceId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Unique ID for a `Table` +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type)] +#[sqlx(transparent)] +pub struct TableId(i64); + +#[allow(missing_docs)] +impl TableId { + pub const fn new(v: i64) -> Self { + Self(v) + } + + pub fn get(&self) -> i64 { + self.0 + } + + pub const fn to_be_bytes(&self) -> [u8; 8] { + self.0.to_be_bytes() + } +} + +impl std::fmt::Display for TableId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/// A sequence number from an ingester +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct SequenceNumber(u64); + +#[allow(missing_docs)] +impl SequenceNumber { + pub fn new(v: u64) -> Self { + Self(v) + } + pub fn get(&self) -> u64 { + self.0 + } +} + +impl Add for SequenceNumber { + type Output = Self; + + fn add(self, 
other: u64) -> Self { + Self(self.0 + other) + } +} + +impl Sub for SequenceNumber { + type Output = Self; + + fn sub(self, other: u64) -> Self { + Self(self.0 - other) + } +} + +/// A time in nanoseconds from epoch. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type)] +#[sqlx(transparent)] +pub struct Timestamp(i64); + +#[allow(missing_docs)] +impl Timestamp { + pub fn new(v: i64) -> Self { + Self(v) + } + + pub fn get(&self) -> i64 { + self.0 + } +} + +impl From for Timestamp { + fn from(time: iox_time::Time) -> Self { + Self::new(time.timestamp_nanos()) + } +} + +impl From for iox_time::Time { + fn from(time: Timestamp) -> iox_time::Time { + iox_time::Time::from_timestamp_nanos(time.get()) + } +} + +impl Add for Timestamp { + type Output = Self; + + fn add(self, other: Self) -> Self { + Self(self.0.checked_add(other.0).expect("timestamp wraparound")) + } +} + +impl Sub for Timestamp { + type Output = Self; + + fn sub(self, other: Self) -> Self { + Self(self.0.checked_sub(other.0).expect("timestamp wraparound")) + } +} + +impl Add for Timestamp { + type Output = Self; + + fn add(self, rhs: i64) -> Self::Output { + self + Self(rhs) + } +} + +impl Sub for Timestamp { + type Output = Self; + + fn sub(self, rhs: i64) -> Self::Output { + self - Self(rhs) + } +} + +/// Unique ID for a `ParquetFile` +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type)] +#[sqlx(transparent)] +pub struct ParquetFileId(i64); + +#[allow(missing_docs)] +impl ParquetFileId { + pub fn new(v: i64) -> Self { + Self(v) + } + pub fn get(&self) -> i64 { + self.0 + } +} + +impl std::fmt::Display for ParquetFileId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Use `self.number` to refer to each positional data point. + write!(f, "{}", self.0) + } +} + +/// Unique store UUID for a [`ParquetFile`]. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type)]
+#[sqlx(transparent)]
+pub struct ObjectStoreId(Uuid);
+
+#[allow(missing_docs)]
+impl ObjectStoreId {
+    #[allow(clippy::new_without_default)]
+    pub fn new() -> Self {
+        Self::from_uuid(Uuid::new_v4())
+    }
+
+    pub fn from_uuid(uuid: Uuid) -> Self {
+        Self(uuid)
+    }
+
+    pub fn get_uuid(&self) -> Uuid {
+        self.0
+    }
+}
+
+impl std::fmt::Display for ObjectStoreId {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+impl std::str::FromStr for ObjectStoreId {
+    type Err = uuid::Error;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let uuid = Uuid::parse_str(s)?;
+        Ok(Self::from_uuid(uuid))
+    }
+}
+
+/// Data object for a namespace
+#[derive(Debug, Clone, PartialEq, sqlx::FromRow)]
+pub struct Namespace {
+    /// The id of the namespace
+    pub id: NamespaceId,
+    /// The unique name of the namespace
+    pub name: String,
+    /// The retention period in ns. None represents infinite duration (i.e. never drop data).
+    pub retention_period_ns: Option<i64>,
+    /// The maximum number of tables that can exist in this namespace
+    pub max_tables: MaxTables,
+    /// The maximum number of columns per table in this namespace
+    pub max_columns_per_table: MaxColumnsPerTable,
+    /// When this namespace was marked for deletion.
+    pub deleted_at: Option<Timestamp>,
+    /// The partition template to use for new tables in this namespace either created implicitly or
+    /// created without specifying a partition template.
+    pub partition_template: NamespacePartitionTemplateOverride,
+}
+
+/// Schema collection for a namespace. This is an in-memory object useful for a schema
+/// cache.
+#[derive(Debug, Clone, PartialEq, Hash)]
+pub struct NamespaceSchema {
+    /// the namespace id
+    pub id: NamespaceId,
+    /// the tables in the namespace by name
+    pub tables: BTreeMap<String, TableSchema>,
+    /// The maximum number of tables permitted in this namespace.
+    pub max_tables: MaxTables,
+    /// the number of columns per table this namespace allows
+    pub max_columns_per_table: MaxColumnsPerTable,
+    /// The retention period in ns.
+    /// None represents infinite duration (i.e. never drop data).
+    pub retention_period_ns: Option<i64>,
+    /// The partition template to use for new tables in this namespace either created implicitly or
+    /// created without specifying a partition template.
+    pub partition_template: NamespacePartitionTemplateOverride,
+}
+
+impl NamespaceSchema {
+    /// Start a new `NamespaceSchema` with empty `tables` but the rest of the information populated
+    /// from the given `Namespace`.
+    pub fn new_empty_from(namespace: &Namespace) -> Self {
+        let &Namespace {
+            id,
+            retention_period_ns,
+            max_tables,
+            max_columns_per_table,
+            ref partition_template,
+            ..
+        } = namespace;
+
+        Self {
+            id,
+            tables: BTreeMap::new(),
+            max_tables,
+            max_columns_per_table,
+            retention_period_ns,
+            partition_template: partition_template.clone(),
+        }
+    }
+
+    /// Estimated Size in bytes including `self`.
+    pub fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+            + self
+                .tables
+                .iter()
+                .map(|(k, v)| size_of_val(k) + k.capacity() + v.size())
+                .sum::<usize>()
+    }
+}
+
+impl From<&NamespaceSchema> for generated_types::influxdata::iox::schema::v1::NamespaceSchema {
+    fn from(schema: &NamespaceSchema) -> Self {
+        namespace_schema_proto(schema.id, schema.tables.iter())
+    }
+}
+
+/// Generate [`NamespaceSchema`] protobuf from a `NamespaceId` and a list of tables. Useful to
+/// filter the tables returned from an API request to a particular table without needing to clone
+/// the whole `NamespaceSchema` to use the `From` impl.
+pub fn namespace_schema_proto<'a>( + id: NamespaceId, + tables: impl Iterator, +) -> generated_types::influxdata::iox::schema::v1::NamespaceSchema { + use generated_types::influxdata::iox::schema::v1 as proto; + proto::NamespaceSchema { + id: id.get(), + tables: tables + .map(|(name, t)| (name.clone(), proto::TableSchema::from(t))) + .collect(), + } +} + +/// Data object for a table +#[derive(Debug, Clone, sqlx::FromRow, PartialEq)] +pub struct Table { + /// The id of the table + pub id: TableId, + /// The namespace id that the table is in + pub namespace_id: NamespaceId, + /// The name of the table, which is unique within the associated namespace + pub name: String, + /// The partition template to use for writes in this table. + pub partition_template: TablePartitionTemplateOverride, +} + +/// Serialise a [`Table`] object into its protobuf representation. +impl From

for generated_types::influxdata::iox::table::v1::Table { + fn from(value: Table) -> Self { + generated_types::influxdata::iox::table::v1::Table { + id: value.id.get(), + name: value.name, + namespace_id: value.namespace_id.get(), + partition_template: value.partition_template.as_proto().cloned(), + } + } +} + +/// Column definitions for a table +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct TableSchema { + /// the table id + pub id: TableId, + + /// The partition template to use for writes in this table. + pub partition_template: TablePartitionTemplateOverride, + + /// the table's columns by their name + pub columns: ColumnsByName, +} + +impl TableSchema { + /// Initialize new `TableSchema` from the information in the given `Table`. + pub fn new_empty_from(table: &Table) -> Self { + Self { + id: table.id, + partition_template: table.partition_template.clone(), + columns: ColumnsByName::default(), + } + } + + /// Add `col` to this table schema. + /// + /// # Panics + /// + /// This method panics if a column of the same name already exists in + /// `self`, or if `col` references a different `table_id`. + pub fn add_column(&mut self, col: Column) { + let Column { + id, + name, + column_type, + table_id, + } = col; + + assert_eq!(table_id, self.id); + + let column_schema = ColumnSchema { id, column_type }; + self.add_column_schema(name, column_schema); + } + + /// Add the name and column schema to this table's schema. + /// + /// # Panics + /// + /// This method panics if a column of the same name already exists in + /// `self`. + pub fn add_column_schema( + &mut self, + column_name: impl Into>, + column_schema: ColumnSchema, + ) { + self.columns.add_column(column_name, column_schema); + } + + /// Estimated Size in bytes including `self`. + pub fn size(&self) -> usize { + size_of_val(self) + + self + .columns + .iter() + .map(|(k, v)| size_of_val(k) + k.as_ref().len() + size_of_val(v)) + .sum::() + } + + /// Create `ID->name` map for columns. 
+ pub fn column_id_map(&self) -> HashMap> { + self.columns.id_map() + } + + /// Whether a column with this name is in the schema. + pub fn contains_column_name(&self, name: &str) -> bool { + self.columns.contains_column_name(name) + } + + /// Return the set of column names for this table. Used in combination with a write operation's + /// column names to determine whether a write would exceed the max allowed columns. + pub fn column_names(&self) -> BTreeSet<&str> { + self.columns.names() + } + + /// Return number of columns of the table + pub fn column_count(&self) -> usize { + self.columns.column_count() + } +} + +impl From<&TableSchema> for generated_types::influxdata::iox::schema::v1::TableSchema { + fn from(table_schema: &TableSchema) -> Self { + use generated_types::influxdata::iox::schema::v1 as proto; + + Self { + id: table_schema.id.get(), + columns: table_schema + .columns + .iter() + .map(|(name, c)| { + ( + name.to_string(), + proto::ColumnSchema { + id: c.id.get(), + column_type: c.column_type as i32, + }, + ) + }) + .collect(), + } + } +} + +/// Data recorded when compaction skips a partition. 
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::FromRow)] +pub struct SkippedCompaction { + /// the partition + pub partition_id: PartitionId, + /// the reason compaction was skipped + pub reason: String, + /// when compaction was skipped + pub skipped_at: Timestamp, + /// estimated memory budget + pub estimated_bytes: i64, + /// limit on memory budget + pub limit_bytes: i64, + /// num files selected to compact + pub num_files: i64, + /// limit on num files + pub limit_num_files: i64, + /// limit on num files for the first file in a partition + pub limit_num_files_first_in_partition: i64, +} + +impl From + for generated_types::influxdata::iox::skipped_compaction::v1::SkippedCompaction +{ + fn from(skipped_compaction: SkippedCompaction) -> Self { + let SkippedCompaction { + partition_id, + reason, + skipped_at, + estimated_bytes, + limit_bytes, + num_files, + limit_num_files, + limit_num_files_first_in_partition, + } = skipped_compaction; + + Self { + partition_id: partition_id.get(), + reason, + skipped_at: skipped_at.get(), + estimated_bytes, + limit_bytes, + num_files, + limit_num_files, + limit_num_files_first_in_partition, + } + } +} + +impl From + for SkippedCompaction +{ + fn from( + skipped_compaction: generated_types::influxdata::iox::skipped_compaction::v1::SkippedCompaction, + ) -> Self { + Self { + partition_id: PartitionId::new(skipped_compaction.partition_id), + reason: skipped_compaction.reason, + skipped_at: Timestamp::new(skipped_compaction.skipped_at), + estimated_bytes: skipped_compaction.estimated_bytes, + limit_bytes: skipped_compaction.limit_bytes, + num_files: skipped_compaction.num_files, + limit_num_files: skipped_compaction.limit_num_files, + limit_num_files_first_in_partition: skipped_compaction + .limit_num_files_first_in_partition, + } + } +} + +/// Data for a parquet file reference that has been inserted in the catalog. 
+#[derive(Debug, Clone, PartialEq, Eq, sqlx::FromRow)] +pub struct ParquetFile { + /// the id of the file in the catalog + pub id: ParquetFileId, + /// the namespace + pub namespace_id: NamespaceId, + /// the table + pub table_id: TableId, + /// the partition identifier + pub partition_id: PartitionId, + /// the optional partition hash id + pub partition_hash_id: Option, + /// the uuid used in the object store path for this file + pub object_store_id: ObjectStoreId, + /// the min timestamp of data in this file + pub min_time: Timestamp, + /// the max timestamp of data in this file + pub max_time: Timestamp, + /// When this file was marked for deletion + pub to_delete: Option, + /// file size in bytes + pub file_size_bytes: i64, + /// the number of rows of data in this file + pub row_count: i64, + /// The compaction level of the file. + /// + /// * 0 (`CompactionLevel::Initial`): represents a level-0 file that is persisted by an + /// Ingester. Partitions with level-0 files are usually hot/recent partitions. + /// * 1 (`CompactionLevel::FileOverlapped`): represents a level-1 file that is persisted by a + /// Compactor and potentially overlaps with other level-1 files. Partitions with level-1 + /// files are partitions with a lot of or/and large overlapped files that have to go + /// through many compaction cycles before they are fully compacted to non-overlapped + /// files. + /// * 2 (`CompactionLevel::FileNonOverlapped`): represents a level-1 file that is persisted by + /// a Compactor and does not overlap with other files except level 0 ones. Eventually, + /// cold partitions (partitions that no longer needs to get compacted) will only include + /// one or many level-1 files + pub compaction_level: CompactionLevel, + /// the creation time of the parquet file + pub created_at: Timestamp, + /// Set of columns within this parquet file. + /// + /// # Relation to Table-wide Column Set + /// Columns within this set may or may not be part of the table-wide schema. 
+ /// + /// Columns that are NOT part of the table-wide schema must be ignored. It is likely that these + /// columns were originally part of the table but were later removed. + /// + /// # Column Types + /// Column types are identical to the table-wide types. + /// + /// # Column Order & Sort Key + /// The columns that are present in the table-wide schema are sorted according to the partition + /// sort key. The occur in the parquet file according to this order. + pub column_set: ColumnSet, + /// the max of created_at of all L0 files needed for file/chunk ordering for deduplication + pub max_l0_created_at: Timestamp, +} + +impl ParquetFile { + /// Create new file from given parameters and ID. + /// + /// [`to_delete`](Self::to_delete) will be set to `None`. + pub fn from_params(params: ParquetFileParams, id: ParquetFileId) -> Self { + Self { + id, + partition_id: params.partition_id, + partition_hash_id: params.partition_hash_id, + namespace_id: params.namespace_id, + table_id: params.table_id, + object_store_id: params.object_store_id, + min_time: params.min_time, + max_time: params.max_time, + to_delete: None, + file_size_bytes: params.file_size_bytes, + row_count: params.row_count, + compaction_level: params.compaction_level, + created_at: params.created_at, + column_set: params.column_set, + max_l0_created_at: params.max_l0_created_at, + } + } + + /// Estimate the memory consumption of this object and its contents + pub fn size(&self) -> usize { + let hash_id = self + .partition_hash_id + .as_ref() + .map(|x| x.size()) + .unwrap_or_default(); + + std::mem::size_of_val(self) + hash_id + self.column_set.size() + - std::mem::size_of_val(&self.column_set) + } + + /// Return true if the time range overlaps with the time range of the given file + pub fn overlaps(&self, other: &Self) -> bool { + self.min_time <= other.max_time && self.max_time >= other.min_time + } + + /// Return true if the time range of this file overlaps with the given time range + pub fn 
overlaps_time_range(&self, min_time: Timestamp, max_time: Timestamp) -> bool { + self.min_time <= max_time && self.max_time >= min_time + } + + /// Return true if the time range of this file overlaps with any of the given split times. + pub fn needs_split(&self, split_times: &Vec) -> bool { + for t in split_times { + // split time is the last timestamp on the "left" side of the split, if it equals + // the min time, one ns goes left, the rest goes right. + if self.min_time.get() <= *t && self.max_time.get() > *t { + return true; + } + } + false + } + + /// Return true if the time range of this file overlaps with any of the given file ranges + pub fn overlaps_ranges(&self, ranges: &Vec) -> bool { + for range in ranges { + if self.min_time.get() <= range.max && self.max_time.get() >= range.min { + return true; + } + } + false + } + + /// Temporary to aid incremental migration + pub fn transition_partition_id(&self) -> TransitionPartitionId { + TransitionPartitionId::from_parts(self.partition_id, self.partition_hash_id.clone()) + } +} + +impl From for generated_types::influxdata::iox::catalog::v1::ParquetFile { + fn from(v: ParquetFile) -> Self { + Self { + id: v.id.get(), + namespace_id: v.namespace_id.get(), + table_id: v.table_id.get(), + partition_id: v.partition_id.get(), + partition_hash_id: v + .partition_hash_id + .map(|x| x.as_bytes().to_vec()) + .unwrap_or_default(), + object_store_id: v.object_store_id.to_string(), + min_time: v.min_time.get(), + max_time: v.max_time.get(), + to_delete: v.to_delete.map(|v| v.get()), + file_size_bytes: v.file_size_bytes, + row_count: v.row_count, + compaction_level: v.compaction_level as i32, + created_at: v.created_at.get(), + column_set: v.column_set.iter().map(|v| v.get()).collect(), + max_l0_created_at: v.max_l0_created_at.get(), + } + } +} + +/// Errors deserialising a protobuf serialised [`ParquetFile`]. +#[derive(Debug, Error)] +pub enum ParquetFileProtoError { + /// The proto type does not contain a partition ID. 
+ #[error("no partition id specified for parquet file")] + NoPartitionId, + + /// The specified partition ID is invalid. + #[error(transparent)] + InvalidPartitionId(#[from] PartitionIdProtoError), + + /// The specified object store UUID is invalid. + #[error("invalid object store ID: {0}")] + InvalidObjectStoreId(uuid::Error), + + /// The specified compaction level value is invalid. + #[error(transparent)] + InvalidCompactionLevel(#[from] CompactionLevelProtoError), +} + +/// Data for a parquet file to be inserted into the catalog. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ParquetFileParams { + /// the namespace + pub namespace_id: NamespaceId, + /// the table + pub table_id: TableId, + /// the partition identifier + pub partition_id: PartitionId, + /// the partition hash ID + pub partition_hash_id: Option, + /// the uuid used in the object store path for this file + pub object_store_id: ObjectStoreId, + /// the min timestamp of data in this file + pub min_time: Timestamp, + /// the max timestamp of data in this file + pub max_time: Timestamp, + /// file size in bytes + pub file_size_bytes: i64, + /// the number of rows of data in this file + pub row_count: i64, + /// the compaction level of the file + pub compaction_level: CompactionLevel, + /// the creation time of the parquet file + pub created_at: Timestamp, + /// columns in this file. + pub column_set: ColumnSet, + /// the max of created_at of all L0 files + pub max_l0_created_at: Timestamp, +} + +/// ID of a chunk. +/// +/// This ID is unique within a single partition. +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ChunkId(Uuid); + +impl ChunkId { + /// Create new, random ID. + #[allow(clippy::new_without_default)] // `new` creates non-deterministic result + pub fn new() -> Self { + Self(Uuid::new_v4()) + } + + /// **TESTING ONLY:** Create new ID from integer. 
+ /// + /// Since this can easily lead to ID collisions (which in turn can lead to panics), this must + /// only be used for testing purposes! + pub fn new_test(id: u128) -> Self { + Self(Uuid::from_u128(id)) + } + + /// The chunk id is only effective in case the chunk's order is the same with another chunk. + /// Hence collisions are safe in that context. + pub fn new_id(id: u128) -> Self { + Self(Uuid::from_u128(id)) + } + + /// Get inner UUID. + pub fn get(&self) -> Uuid { + self.0 + } +} + +impl std::fmt::Debug for ChunkId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + ::fmt(self, f) + } +} + +impl std::fmt::Display for ChunkId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if (self.0.get_variant() == uuid::Variant::RFC4122) + && (self.0.get_version() == Some(uuid::Version::Random)) + { + f.debug_tuple("ChunkId").field(&self.0).finish() + } else { + f.debug_tuple("ChunkId").field(&self.0.as_u128()).finish() + } + } +} + +impl From for ChunkId { + fn from(id: ObjectStoreId) -> Self { + Self(id.get_uuid()) + } +} + +/// Order of a chunk. +/// +/// This is used for: +/// 1. **upsert order:** chunks with higher order overwrite data in chunks with lower order +/// 2. **locking order:** chunks must be locked in consistent (ascending) order +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ChunkOrder(i64); + +impl ChunkOrder { + /// The minimum ordering value a chunk could have. Currently only used in testing. + pub const MIN: Self = Self(0); + + /// The maximum chunk order. + pub const MAX: Self = Self(i64::MAX); + + /// Create a ChunkOrder from the given value. + pub fn new(order: i64) -> Self { + Self(order) + } + + /// Under underlying order as integer. + pub fn get(&self) -> i64 { + self.0 + } +} + +/// Represents a parsed delete predicate for evaluation by the InfluxDB IOx +/// query engine. 
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DeletePredicate {
    /// Only rows within this range are included in
    /// results. Other rows are excluded.
    pub range: TimestampRange,

    /// Optional arbitrary predicates, represented as list of
    /// expressions applied a logical conjunction (aka they
    /// are 'AND'ed together). Only rows that evaluate to TRUE for all
    /// these expressions should be returned. Other rows are excluded
    /// from the results.
    pub exprs: Vec<DeleteExpr>,
}

impl DeletePredicate {
    /// Format expr to SQL string.
    pub fn expr_sql_string(&self) -> String {
        let mut out = String::new();
        for expr in &self.exprs {
            if !out.is_empty() {
                write!(&mut out, " AND ").expect("writing to a string shouldn't fail");
            }
            write!(&mut out, "{expr}").expect("writing to a string shouldn't fail");
        }
        out
    }

    /// Return the approximate memory size of the predicate, in bytes.
    ///
    /// This includes `Self`.
    pub fn size(&self) -> usize {
        std::mem::size_of::<Self>() + self.exprs.iter().map(|expr| expr.size()).sum::<usize>()
    }

    /// Return the delete predicate for data outside retention.
    /// We need to only retain time >= retention_time.
    /// Thus we only need to set the range to MIN < time < retention_time.
    pub fn retention_delete_predicate(retention_time: i64) -> Self {
        let range = TimestampRange {
            start: i64::MIN,
            end: retention_time,
        };
        Self {
            range,
            exprs: vec![],
        }
    }
}

/// Single expression to be used as parts of a predicate.
///
/// Only very simple expression of the type `<column> <op> <scalar>` are supported.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DeleteExpr {
    /// Column (w/o table name).
    pub column: String,

    /// Operator.
    pub op: Op,

    /// Scalar value.
    pub scalar: Scalar,
}

impl DeleteExpr {
    /// Create a new [`DeleteExpr`]
    pub fn new(column: String, op: Op, scalar: Scalar) -> Self {
        Self { column, op, scalar }
    }

    /// Column (w/o table name).
    pub fn column(&self) -> &str {
        &self.column
    }

    /// Operator.
    pub fn op(&self) -> Op {
        self.op
    }

    /// Scalar value.
    pub fn scalar(&self) -> &Scalar {
        &self.scalar
    }

    /// Return the approximate memory size of the expression, in bytes.
    ///
    /// This includes `Self`.
    pub fn size(&self) -> usize {
        std::mem::size_of::<Self>() + self.column.capacity() + self.scalar.size()
    }
}

impl std::fmt::Display for DeleteExpr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Backslashes and double quotes inside the column name are escaped so
        // the output is a valid quoted SQL identifier.
        write!(
            f,
            r#""{}"{}{}"#,
            self.column().replace('\\', r"\\").replace('"', r#"\""#),
            self.op(),
            self.scalar(),
        )
    }
}

/// Binary operator that can be evaluated on a column and a scalar value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Op {
    /// Strict equality (`=`).
    Eq,

    /// Inequality (`!=`).
    Ne,
}

impl std::fmt::Display for Op {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Eq => write!(f, "="),
            Self::Ne => write!(f, "!="),
        }
    }
}

/// Scalar value of a certain type.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[allow(missing_docs)]
pub enum Scalar {
    Bool(bool),
    I64(i64),
    F64(ordered_float::OrderedFloat<f64>),
    String(String),
}

impl Scalar {
    /// Return the approximate memory size of the scalar, in bytes.
    ///
    /// This includes `Self`.
+ pub fn size(&self) -> usize { + std::mem::size_of::() + + match &self { + Self::Bool(_) | Self::I64(_) | Self::F64(_) => 0, + Self::String(s) => s.capacity(), + } + } +} + +impl std::fmt::Display for Scalar { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Scalar::Bool(value) => value.fmt(f), + Scalar::I64(value) => value.fmt(f), + Scalar::F64(value) => match value.classify() { + FpCategory::Nan => write!(f, "'NaN'"), + FpCategory::Infinite if *value.as_ref() < 0.0 => write!(f, "'-Infinity'"), + FpCategory::Infinite => write!(f, "'Infinity'"), + _ => write!(f, "{:?}", value.as_ref()), + }, + Scalar::String(value) => { + write!(f, "'{}'", value.replace('\\', r"\\").replace('\'', r"\'")) + } + } + } +} + +/// A string that cannot be empty +/// +/// This is particularly useful for types that map to/from protobuf, where string fields +/// are not nullable - that is they default to an empty string if not specified +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct NonEmptyString(Box); + +impl NonEmptyString { + /// Create a new `NonEmptyString` from the provided `String` + /// + /// Returns None if empty + pub fn new(s: impl Into) -> Option { + let s = s.into(); + match s.is_empty() { + true => None, + false => Some(Self(s.into_boxed_str())), + } + } +} + +impl Deref for NonEmptyString { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +/// Column name, statistics which encode type information +#[derive(Debug, PartialEq, Clone)] +pub struct ColumnSummary { + /// Column name + pub name: String, + + /// Column's Influx data model type + pub influxdb_type: InfluxDbType, + + /// Per column + pub stats: Statistics, +} + +impl ColumnSummary { + /// Returns the total number of rows (including nulls) in this column + pub fn total_count(&self) -> u64 { + self.stats.total_count() + } + + /// Updates statistics from other if the same type, otherwise a noop + pub fn 
update_from(&mut self, other: &Self) { + match (&mut self.stats, &other.stats) { + (Statistics::F64(s), Statistics::F64(o)) => { + s.update_from(o); + } + (Statistics::I64(s), Statistics::I64(o)) => { + s.update_from(o); + } + (Statistics::Bool(s), Statistics::Bool(o)) => { + s.update_from(o); + } + (Statistics::String(s), Statistics::String(o)) => { + s.update_from(o); + } + (Statistics::U64(s), Statistics::U64(o)) => { + s.update_from(o); + } + // do catch alls for the specific types, that way if a new type gets added, the compiler + // will complain. + (Statistics::F64(_), _) => unreachable!(), + (Statistics::I64(_), _) => unreachable!(), + (Statistics::U64(_), _) => unreachable!(), + (Statistics::Bool(_), _) => unreachable!(), + (Statistics::String(_), _) => unreachable!(), + } + } + + /// Updates these statistics so that that the total length of this + /// column is `len` rows, padding it with trailing NULLs if + /// necessary + pub fn update_to_total_count(&mut self, len: u64) { + let total_count = self.total_count(); + assert!( + total_count <= len, + "trying to shrink column stats from {total_count} to {len}" + ); + let delta = len - total_count; + self.stats.update_for_nulls(delta); + } + + /// Return size in bytes of this Column metadata (not the underlying column) + pub fn size(&self) -> usize { + mem::size_of::() + self.name.len() + self.stats.size() + } +} + +// Replicate this enum here as it can't be derived from the existing statistics +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[allow(missing_docs)] +pub enum InfluxDbType { + Tag, + Field, + Timestamp, +} + +/// Summary statistics for a column. 
+#[derive(Debug, Clone, Eq, PartialEq)] +pub struct StatValues { + /// minimum (non-NaN, non-NULL) value, if any + pub min: Option, + + /// maximum (non-NaN, non-NULL) value, if any + pub max: Option, + + /// total number of values in this column, including null values + pub total_count: u64, + + /// number of null values in this column + pub null_count: Option, + + /// number of distinct values in this column if known + /// + /// This includes NULLs and NANs + pub distinct_count: Option, +} + +/// Represents the result of comparing the min/max ranges of two [`StatValues`] +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum StatOverlap { + /// There is at least one value that exists in both ranges + NonZero, + + /// There are zero values that exists in both ranges + Zero, + + /// It is not known if there are any intersections (e.g. because + /// one of the bounds is not Known / is None) + Unknown, +} + +impl StatValues +where + T: PartialOrd, +{ + /// returns information about the overlap between two `StatValues` + pub fn overlaps(&self, other: &Self) -> StatOverlap { + match (&self.min, &self.max, &other.min, &other.max) { + (Some(self_min), Some(self_max), Some(other_min), Some(other_max)) => { + if self_min <= other_max && self_max >= other_min { + StatOverlap::NonZero + } else { + StatOverlap::Zero + } + } + // At least one of the values was None + _ => StatOverlap::Unknown, + } + } +} + +impl Default for StatValues { + fn default() -> Self { + Self { + min: None, + max: None, + total_count: 0, + null_count: None, + distinct_count: None, + } + } +} + +impl StatValues { + /// Create new statistics with no values + pub fn new_empty() -> Self { + Self { + min: None, + max: None, + total_count: 0, + null_count: Some(0), + distinct_count: None, + } + } + + /// Returns true if both the min and max values are None (aka not known) + pub fn is_none(&self) -> bool { + self.min.is_none() && self.max.is_none() + } + + /// Update the statistics values to account for 
`num_nulls` additional null values + pub fn update_for_nulls(&mut self, num_nulls: u64) { + self.total_count += num_nulls; + self.null_count = self.null_count.map(|x| x + num_nulls); + } + + /// updates the statistics keeping the min, max and incrementing count. + /// + /// The type plumbing exists to allow calling with `&str` on a `StatValues`. + pub fn update(&mut self, other: &U) + where + T: Borrow, + U: ToOwned + PartialOrd + IsNan, + { + self.total_count += 1; + self.distinct_count = None; + + if !other.is_nan() { + match &self.min { + None => self.min = Some(other.to_owned()), + Some(s) => { + if s.borrow() > other { + self.min = Some(other.to_owned()); + } + } + } + + match &self.max { + None => { + self.max = Some(other.to_owned()); + } + Some(s) => { + if other > s.borrow() { + self.max = Some(other.to_owned()); + } + } + } + } + } +} + +impl StatValues +where + T: Clone + PartialOrd, +{ + /// Updates statistics from other + pub fn update_from(&mut self, other: &Self) { + self.total_count += other.total_count; + self.null_count = self.null_count.zip(other.null_count).map(|(a, b)| a + b); + + // No way to accurately aggregate counts + self.distinct_count = None; + + match (&self.min, &other.min) { + (None, None) | (Some(_), None) => {} + (None, Some(o)) => self.min = Some(o.clone()), + (Some(s), Some(o)) => { + if s > o { + self.min = Some(o.clone()); + } + } + } + + match (&self.max, &other.max) { + (None, None) | (Some(_), None) => {} + (None, Some(o)) => self.max = Some(o.clone()), + (Some(s), Some(o)) => { + if o > s { + self.max = Some(o.clone()); + } + } + }; + } +} + +impl StatValues +where + T: IsNan + PartialOrd, +{ + /// Create new statistics with the specified count and null count + pub fn new(min: Option, max: Option, total_count: u64, null_count: Option) -> Self { + let distinct_count = None; + Self::new_with_distinct(min, max, total_count, null_count, distinct_count) + } + + /// Create statistics for a column that only has nulls up to now + 
pub fn new_all_null(total_count: u64, distinct_count: Option) -> Self { + let min = None; + let max = None; + let null_count = Some(total_count); + + if let Some(count) = distinct_count { + assert!(count > 0); + } + Self::new_with_distinct( + min, + max, + total_count, + null_count, + distinct_count.map(|c| NonZeroU64::new(c).unwrap()), + ) + } + + /// Create statistics for a column with zero nulls and unknown distinct count + pub fn new_non_null(min: Option, max: Option, total_count: u64) -> Self { + let null_count = Some(0); + let distinct_count = None; + Self::new_with_distinct(min, max, total_count, null_count, distinct_count) + } + + /// Create new statistics with the specified count and null count and distinct values + pub fn new_with_distinct( + min: Option, + max: Option, + total_count: u64, + null_count: Option, + distinct_count: Option, + ) -> Self { + if let Some(min) = &min { + assert!(!min.is_nan()); + } + if let Some(max) = &max { + assert!(!max.is_nan()); + } + if let (Some(min), Some(max)) = (&min, &max) { + assert!(min <= max); + } + + Self { + min, + max, + total_count, + null_count, + distinct_count, + } + } +} + +/// Whether a type is NaN or not. +pub trait IsNan { + /// Test for NaNess. + fn is_nan(&self) -> bool; +} + +impl IsNan for &T { + fn is_nan(&self) -> bool { + (*self).is_nan() + } +} + +macro_rules! impl_is_nan_false { + ($t:ty) => { + impl IsNan for $t { + fn is_nan(&self) -> bool { + false + } + } + }; +} + +impl_is_nan_false!(bool); +impl_is_nan_false!(str); +impl_is_nan_false!(String); +impl_is_nan_false!(i8); +impl_is_nan_false!(i16); +impl_is_nan_false!(i32); +impl_is_nan_false!(i64); +impl_is_nan_false!(u8); +impl_is_nan_false!(u16); +impl_is_nan_false!(u32); +impl_is_nan_false!(u64); + +impl IsNan for f64 { + fn is_nan(&self) -> bool { + Self::is_nan(*self) + } +} + +/// Statistics and type information for a column. 
#[derive(Debug, PartialEq, Clone)]
#[allow(missing_docs)]
pub enum Statistics {
    I64(StatValues<i64>),
    U64(StatValues<u64>),
    Bool(StatValues<bool>),
    String(StatValues<String>),

    /// For the purposes of min/max values of floats, NaN values are ignored (no
    /// ordering is applied to NaNs).
    F64(StatValues<f64>),
}

impl Statistics {
    /// Returns the total number of rows in this column
    pub fn total_count(&self) -> u64 {
        match self {
            Self::I64(s) => s.total_count,
            Self::U64(s) => s.total_count,
            Self::F64(s) => s.total_count,
            Self::Bool(s) => s.total_count,
            Self::String(s) => s.total_count,
        }
    }

    /// Returns true if both the min and max values are None (aka not known)
    pub fn is_none(&self) -> bool {
        match self {
            Self::I64(v) => v.is_none(),
            Self::U64(v) => v.is_none(),
            Self::F64(v) => v.is_none(),
            Self::Bool(v) => v.is_none(),
            Self::String(v) => v.is_none(),
        }
    }

    /// Returns the number of null rows in this column, if known
    pub fn null_count(&self) -> Option<u64> {
        match self {
            Self::I64(s) => s.null_count,
            Self::U64(s) => s.null_count,
            Self::F64(s) => s.null_count,
            Self::Bool(s) => s.null_count,
            Self::String(s) => s.null_count,
        }
    }

    /// Returns the distinct count if known
    pub fn distinct_count(&self) -> Option<NonZeroU64> {
        match self {
            Self::I64(s) => s.distinct_count,
            Self::U64(s) => s.distinct_count,
            Self::F64(s) => s.distinct_count,
            Self::Bool(s) => s.distinct_count,
            Self::String(s) => s.distinct_count,
        }
    }

    /// Update the statistics values to account for `num_nulls` additional null values
    pub fn update_for_nulls(&mut self, num_nulls: u64) {
        match self {
            Self::I64(v) => v.update_for_nulls(num_nulls),
            Self::U64(v) => v.update_for_nulls(num_nulls),
            Self::F64(v) => v.update_for_nulls(num_nulls),
            Self::Bool(v) => v.update_for_nulls(num_nulls),
            Self::String(v) => v.update_for_nulls(num_nulls),
        }
    }

    /// Return the size in bytes of this stats instance
    pub fn size(&self) -> usize {
        match self {
Self::String(v) => std::mem::size_of::() + v.string_size(), + _ => std::mem::size_of::(), + } + } + + /// Return a human interpretable description of this type + pub fn type_name(&self) -> &'static str { + match self { + Self::I64(_) => "I64", + Self::U64(_) => "U64", + Self::F64(_) => "F64", + Self::Bool(_) => "Bool", + Self::String(_) => "String", + } + } + + /// Extract i64 type. + pub fn as_i64(&self) -> Option<&StatValues> { + match self { + Self::I64(val) => Some(val), + _ => None, + } + } +} + +impl StatValues { + /// Returns the bytes associated by storing min/max string values + pub fn string_size(&self) -> usize { + self.min.as_ref().map(|x| x.len()).unwrap_or(0) + + self.max.as_ref().map(|x| x.len()).unwrap_or(0) + } +} + +/// Metadata and statistics information for a table. This can be +/// either for the portion of a Table stored within a single chunk or +/// aggregated across chunks. +#[derive(Debug, PartialEq, Clone, Default)] +pub struct TableSummary { + /// Per column statistics + pub columns: Vec, +} + +impl TableSummary { + /// Get the column summary by name. + pub fn column(&self, name: &str) -> Option<&ColumnSummary> { + self.columns.iter().find(|c| c.name == name) + } + + /// Returns the total number of rows in the columns of this summary + pub fn total_count(&self) -> u64 { + // Assumes that all tables have the same number of rows, so + // pick the first one + let count = self.columns.first().map(|c| c.total_count()).unwrap_or(0); + + // Validate that the counts are consistent across columns + for c in &self.columns { + // Restore to assert when https://github.com/influxdata/influxdb_iox/issues/2124 is fixed + if c.total_count() != count { + warn!(column_name=%c.name, + column_count=c.total_count(), previous_count=count, + "Mismatch in statistics count, see #2124"); + } + } + count + } + + /// Updates the table summary with combined stats from the other. Counts are + /// treated as non-overlapping so they're just added together. 
If the + /// type of a column differs between the two tables, no update is done + /// on that column. Columns that only exist in the other are cloned into + /// this table summary. + pub fn update_from(&mut self, other: &Self) { + let new_total_count = self.total_count() + other.total_count(); + + // update all existing columns + for col in &mut self.columns { + if let Some(other_col) = other.column(&col.name) { + col.update_from(other_col); + } else { + col.update_to_total_count(new_total_count); + } + } + + // Add any columns that were new + for col in &other.columns { + if self.column(&col.name).is_none() { + let mut new_col = col.clone(); + // ensure the count is consistent + new_col.update_to_total_count(new_total_count); + self.columns.push(new_col); + } + } + } + + /// Total size of all ColumnSummaries that belong to this table which include + /// column names and their stats + pub fn size(&self) -> usize { + let size: usize = self.columns.iter().map(|c| c.size()).sum(); + size + mem::size_of::() // Add size of this struct that points to + // table and ColumnSummary + } + + /// Extracts min/max values of the timestamp column, if possible + pub fn time_range(&self) -> Option { + self.column(TIME_COLUMN_NAME).and_then(|c| { + if let Statistics::I64(StatValues { + min: Some(min), + max: Some(max), + .. + }) = &c.stats + { + Some(TimestampMinMax::new(*min, *max)) + } else { + None + } + }) + } +} + +/// minimum time that can be represented. +/// +/// 1677-09-21 00:12:43.145224194 +0000 UTC +/// +/// The two lowest minimum integers are used as sentinel values. The +/// minimum value needs to be used as a value lower than any other value for +/// comparisons and another separate value is needed to act as a sentinel +/// default value that is unusable by the user, but usable internally. +/// Because these two values need to be used for a special purpose, we do +/// not allow users to write points at these two times. 
///
/// Source: [influxdb](https://github.com/influxdata/influxdb/blob/540bb66e1381a48a6d1ede4fc3e49c75a7d9f4af/models/time.go#L12-L34)
pub const MIN_NANO_TIME: i64 = i64::MIN + 2;

/// maximum time that can be represented.
///
/// 2262-04-11 23:47:16.854775806 +0000 UTC
///
/// The highest time represented by a nanosecond needs to be used for an exclusive range, so the
/// maximum time needs to be one less than the possible maximum number of nanoseconds representable
/// by an int64 so that we don't lose a point at that one time.
/// Source: [influxdb](https://github.com/influxdata/influxdb/blob/540bb66e1381a48a6d1ede4fc3e49c75a7d9f4af/models/time.go#L12-L34)
pub const MAX_NANO_TIME: i64 = i64::MAX - 1;

/// Specifies a continuous range of nanosecond timestamps. Timestamp
/// predicates are so common and critical to performance of timeseries
/// databases in general, and IOx in particular, that they are handled
/// specially
///
/// Timestamp ranges are defined such that a value `v` is within the
/// range iff:
///
/// ```text
/// range.start <= v < range.end
/// ```
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Copy, Debug, Hash)]
pub struct TimestampRange {
    /// Start defines the inclusive lower bound. Minimum value is [MIN_NANO_TIME]
    start: i64,
    /// End defines the exclusive upper bound. Maximum value is [MAX_NANO_TIME]
    end: i64,
}

impl TimestampRange {
    /// Create a new TimestampRange.
    ///
    /// Takes an inclusive start and an exclusive end. You may create an empty range by setting
    /// `start = end`.
    ///
    /// Clamps `start` to [`MIN_NANO_TIME`].
    /// end is unclamped. End may be set to `i64::MAX == MAX_NANO_TIME+1` to indicate that the
    /// upper bound is NOT restricted (this does NOT affect `start` though!).
    ///
    /// If `start > end`, this will be interpreted as an empty time range and `start` will be set
    /// to `end`.
+ pub fn new(start: i64, end: i64) -> Self { + let start = start.clamp(MIN_NANO_TIME, end); + let end = end.max(MIN_NANO_TIME); + Self { start, end } + } + + /// Returns true if this range contains all representable timestamps + pub fn contains_all(&self) -> bool { + self.start <= MIN_NANO_TIME && self.end > MAX_NANO_TIME + } + + /// Returns true if this range contains all representable timestamps except possibly MAX_NANO_TIME + /// + /// This is required for queries from InfluxQL, which are intended to be + /// for all time but instead can be for [MIN_NANO_TIME, MAX_NANO_TIME). + /// When is fixed, + /// all uses of contains_nearly_all should be replaced by contains_all + pub fn contains_nearly_all(&self) -> bool { + self.start <= MIN_NANO_TIME && self.end >= MAX_NANO_TIME + } + + #[inline] + /// Returns true if this range contains the value v + pub fn contains(&self, v: i64) -> bool { + self.start <= v && v < self.end + } + + /// Return the timestamp exclusive range's end. + pub fn end(&self) -> i64 { + self.end + } + + /// Return the timestamp inclusive range's start. + pub fn start(&self) -> i64 { + self.start + } +} + +/// Specifies a min/max timestamp value. +/// +/// Note this differs subtlety (but critically) from a +/// [`TimestampRange`] as the minimum and maximum values are included ([`TimestampRange`] has an exclusive end). +#[derive(Clone, Debug, Copy, PartialEq, Eq)] +pub struct TimestampMinMax { + /// The minimum timestamp value + pub min: i64, + /// the maximum timestamp value + pub max: i64, +} + +impl TimestampMinMax { + /// Create a new TimestampMinMax. Panics if min > max. 
+ pub fn new(min: i64, max: i64) -> Self { + assert!(min <= max, "expected min ({min}) <= max ({max})"); + Self { min, max } + } + + #[inline] + /// Returns true if any of the values between min / max + /// (inclusive) are contained within the specified timestamp range + pub fn overlaps(&self, range: TimestampRange) -> bool { + range.contains(self.min) + || range.contains(self.max) + || (self.min <= range.start && self.max >= range.end) + } + + /// Returns the union of this range with `other` with the minimum of the `min`s + /// and the maximum of the `max`es + + pub fn union(&self, other: &Self) -> Self { + Self { + min: self.min.min(other.min), + max: self.max.max(other.max), + } + } +} + +/// FileRange describes a range of files by the min/max time and the sum of their capacities. +#[derive(Clone, Debug, Copy, PartialEq, Eq)] +pub struct FileRange { + /// The minimum time of any file in the range + pub min: i64, + /// The maximum time of any file in the range + pub max: i64, + /// The sum of the sizes of all files in the range + pub cap: usize, +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::borrow::Cow; + + use ordered_float::OrderedFloat; + + #[test] + fn test_chunk_id_new() { + // `ChunkId::new()` create new random ID + assert_ne!(ChunkId::new(), ChunkId::new()); + } + + #[test] + fn test_chunk_id_new_test() { + // `ChunkId::new_test(...)` creates deterministic ID + assert_eq!(ChunkId::new_test(1), ChunkId::new_test(1)); + assert_ne!(ChunkId::new_test(1), ChunkId::new_test(2)); + } + + #[test] + fn test_chunk_id_debug_and_display() { + // Random chunk IDs use UUID-format + let id_random = ChunkId::new(); + let inner: Uuid = id_random.get(); + assert_eq!(format!("{id_random:?}"), format!("ChunkId({inner})")); + assert_eq!(format!("{id_random}"), format!("ChunkId({inner})")); + + // Deterministic IDs use integer format + let id_test = ChunkId::new_test(42); + assert_eq!(format!("{id_test:?}"), "ChunkId(42)"); + assert_eq!(format!("{id_test}"), 
"ChunkId(42)"); + } + + #[test] + fn test_expr_to_sql_no_expressions() { + let pred = DeletePredicate { + range: TimestampRange::new(1, 2), + exprs: vec![], + }; + assert_eq!(&pred.expr_sql_string(), ""); + } + + #[test] + fn test_expr_to_sql_operators() { + let pred = DeletePredicate { + range: TimestampRange::new(1, 2), + exprs: vec![ + DeleteExpr { + column: String::from("col1"), + op: Op::Eq, + scalar: Scalar::I64(1), + }, + DeleteExpr { + column: String::from("col2"), + op: Op::Ne, + scalar: Scalar::I64(2), + }, + ], + }; + assert_eq!(&pred.expr_sql_string(), r#""col1"=1 AND "col2"!=2"#); + } + + #[test] + fn test_expr_to_sql_column_escape() { + let pred = DeletePredicate { + range: TimestampRange::new(1, 2), + exprs: vec![ + DeleteExpr { + column: String::from("col 1"), + op: Op::Eq, + scalar: Scalar::I64(1), + }, + DeleteExpr { + column: String::from(r"col\2"), + op: Op::Eq, + scalar: Scalar::I64(2), + }, + DeleteExpr { + column: String::from(r#"col"3"#), + op: Op::Eq, + scalar: Scalar::I64(3), + }, + ], + }; + assert_eq!( + &pred.expr_sql_string(), + r#""col 1"=1 AND "col\\2"=2 AND "col\"3"=3"# + ); + } + + #[test] + fn test_expr_to_sql_bool() { + let pred = DeletePredicate { + range: TimestampRange::new(1, 2), + exprs: vec![ + DeleteExpr { + column: String::from("col1"), + op: Op::Eq, + scalar: Scalar::Bool(false), + }, + DeleteExpr { + column: String::from("col2"), + op: Op::Eq, + scalar: Scalar::Bool(true), + }, + ], + }; + assert_eq!(&pred.expr_sql_string(), r#""col1"=false AND "col2"=true"#); + } + + #[test] + fn test_expr_to_sql_i64() { + let pred = DeletePredicate { + range: TimestampRange::new(1, 2), + exprs: vec![ + DeleteExpr { + column: String::from("col1"), + op: Op::Eq, + scalar: Scalar::I64(0), + }, + DeleteExpr { + column: String::from("col2"), + op: Op::Eq, + scalar: Scalar::I64(-1), + }, + DeleteExpr { + column: String::from("col3"), + op: Op::Eq, + scalar: Scalar::I64(1), + }, + DeleteExpr { + column: String::from("col4"), + op: Op::Eq, + 
scalar: Scalar::I64(i64::MIN), + }, + DeleteExpr { + column: String::from("col5"), + op: Op::Eq, + scalar: Scalar::I64(i64::MAX), + }, + ], + }; + assert_eq!( + &pred.expr_sql_string(), + r#""col1"=0 AND "col2"=-1 AND "col3"=1 AND "col4"=-9223372036854775808 AND "col5"=9223372036854775807"# + ); + } + + #[test] + fn test_expr_to_sql_f64() { + let pred = DeletePredicate { + range: TimestampRange::new(1, 2), + exprs: vec![ + DeleteExpr { + column: String::from("col1"), + op: Op::Eq, + scalar: Scalar::F64(OrderedFloat::from(0.0)), + }, + DeleteExpr { + column: String::from("col2"), + op: Op::Eq, + scalar: Scalar::F64(OrderedFloat::from(-0.0)), + }, + DeleteExpr { + column: String::from("col3"), + op: Op::Eq, + scalar: Scalar::F64(OrderedFloat::from(1.0)), + }, + DeleteExpr { + column: String::from("col4"), + op: Op::Eq, + scalar: Scalar::F64(OrderedFloat::from(f64::INFINITY)), + }, + DeleteExpr { + column: String::from("col5"), + op: Op::Eq, + scalar: Scalar::F64(OrderedFloat::from(f64::NEG_INFINITY)), + }, + DeleteExpr { + column: String::from("col6"), + op: Op::Eq, + scalar: Scalar::F64(OrderedFloat::from(f64::NAN)), + }, + ], + }; + assert_eq!( + &pred.expr_sql_string(), + r#""col1"=0.0 AND "col2"=-0.0 AND "col3"=1.0 AND "col4"='Infinity' AND "col5"='-Infinity' AND "col6"='NaN'"# + ); + } + + #[test] + fn test_expr_to_sql_string() { + let pred = DeletePredicate { + range: TimestampRange::new(1, 2), + exprs: vec![ + DeleteExpr { + column: String::from("col1"), + op: Op::Eq, + scalar: Scalar::String(String::from("")), + }, + DeleteExpr { + column: String::from("col2"), + op: Op::Eq, + scalar: Scalar::String(String::from("foo")), + }, + DeleteExpr { + column: String::from("col3"), + op: Op::Eq, + scalar: Scalar::String(String::from(r"fo\o")), + }, + DeleteExpr { + column: String::from("col4"), + op: Op::Eq, + scalar: Scalar::String(String::from(r#"fo'o"#)), + }, + ], + }; + assert_eq!( + &pred.expr_sql_string(), + r#""col1"='' AND "col2"='foo' AND "col3"='fo\\o' AND 
"col4"='fo\'o'"# + ); + } + + #[test] + fn statistics_new_non_null() { + let actual = StatValues::new_non_null(Some(-1i64), Some(1i64), 3); + let expected = StatValues { + min: Some(-1i64), + max: Some(1i64), + total_count: 3, + null_count: Some(0), + distinct_count: None, + }; + assert_eq!(actual, expected); + } + + #[test] + fn statistics_new_all_null() { + // i64 values do not have a distinct count + let actual = StatValues::::new_all_null(3, None); + let expected = StatValues { + min: None, + max: None, + total_count: 3, + null_count: Some(3), + distinct_count: None, + }; + assert_eq!(actual, expected); + + // string columns can have a distinct count + let actual = StatValues::::new_all_null(3, Some(1_u64)); + let expected = StatValues { + min: None, + max: None, + total_count: 3, + null_count: Some(3), + distinct_count: Some(NonZeroU64::try_from(1_u64).unwrap()), + }; + assert_eq!(actual, expected); + } + + impl StatValues + where + T: IsNan + PartialOrd + Clone, + { + fn new_with_value(starting_value: T) -> Self { + let starting_value = if starting_value.is_nan() { + None + } else { + Some(starting_value) + }; + + let min = starting_value.clone(); + let max = starting_value; + let total_count = 1; + let null_count = Some(0); + let distinct_count = None; + Self::new_with_distinct(min, max, total_count, null_count, distinct_count) + } + } + + impl Statistics { + /// Return the minimum value, if any, formatted as a string + fn min_as_str(&self) -> Option> { + match self { + Self::I64(v) => v.min.map(|x| Cow::Owned(x.to_string())), + Self::U64(v) => v.min.map(|x| Cow::Owned(x.to_string())), + Self::F64(v) => v.min.map(|x| Cow::Owned(x.to_string())), + Self::Bool(v) => v.min.map(|x| Cow::Owned(x.to_string())), + Self::String(v) => v.min.as_deref().map(Cow::Borrowed), + } + } + + /// Return the maximum value, if any, formatted as a string + fn max_as_str(&self) -> Option> { + match self { + Self::I64(v) => v.max.map(|x| Cow::Owned(x.to_string())), + Self::U64(v) => 
v.max.map(|x| Cow::Owned(x.to_string())), + Self::F64(v) => v.max.map(|x| Cow::Owned(x.to_string())), + Self::Bool(v) => v.max.map(|x| Cow::Owned(x.to_string())), + Self::String(v) => v.max.as_deref().map(Cow::Borrowed), + } + } + } + + #[test] + fn statistics_update() { + let mut stat = StatValues::new_with_value(23); + assert_eq!(stat.min, Some(23)); + assert_eq!(stat.max, Some(23)); + assert_eq!(stat.total_count, 1); + + stat.update(&55); + assert_eq!(stat.min, Some(23)); + assert_eq!(stat.max, Some(55)); + assert_eq!(stat.total_count, 2); + + stat.update(&6); + assert_eq!(stat.min, Some(6)); + assert_eq!(stat.max, Some(55)); + assert_eq!(stat.total_count, 3); + + stat.update(&30); + assert_eq!(stat.min, Some(6)); + assert_eq!(stat.max, Some(55)); + assert_eq!(stat.total_count, 4); + } + + #[test] + fn statistics_default() { + let mut stat = StatValues::default(); + assert_eq!(stat.min, None); + assert_eq!(stat.max, None); + assert_eq!(stat.total_count, 0); + + stat.update(&55); + assert_eq!(stat.min, Some(55)); + assert_eq!(stat.max, Some(55)); + assert_eq!(stat.total_count, 1); + + let mut stat = StatValues::::default(); + assert_eq!(stat.min, None); + assert_eq!(stat.max, None); + assert_eq!(stat.total_count, 0); + + stat.update("cupcakes"); + assert_eq!(stat.min, Some("cupcakes".to_string())); + assert_eq!(stat.max, Some("cupcakes".to_string())); + assert_eq!(stat.total_count, 1); + + stat.update("woo"); + assert_eq!(stat.min, Some("cupcakes".to_string())); + assert_eq!(stat.max, Some("woo".to_string())); + assert_eq!(stat.total_count, 2); + } + + #[test] + fn statistics_is_none() { + let mut stat = StatValues::default(); + assert!(stat.is_none()); + stat.min = Some(0); + assert!(!stat.is_none()); + stat.max = Some(1); + assert!(!stat.is_none()); + } + + #[test] + fn statistics_overlaps() { + let stat1 = StatValues { + min: Some(10), + max: Some(20), + ..Default::default() + }; + assert_eq!(stat1.overlaps(&stat1), StatOverlap::NonZero); + + // [--stat1--] + 
// [--stat2--] + let stat2 = StatValues { + min: Some(5), + max: Some(15), + ..Default::default() + }; + assert_eq!(stat1.overlaps(&stat2), StatOverlap::NonZero); + assert_eq!(stat2.overlaps(&stat1), StatOverlap::NonZero); + + // [--stat1--] + // [--stat3--] + let stat3 = StatValues { + min: Some(15), + max: Some(25), + ..Default::default() + }; + assert_eq!(stat1.overlaps(&stat3), StatOverlap::NonZero); + assert_eq!(stat3.overlaps(&stat1), StatOverlap::NonZero); + + // [--stat1--] + // [--stat4--] + let stat4 = StatValues { + min: Some(25), + max: Some(35), + ..Default::default() + }; + assert_eq!(stat1.overlaps(&stat4), StatOverlap::Zero); + assert_eq!(stat4.overlaps(&stat1), StatOverlap::Zero); + + // [--stat1--] + // [--stat5--] + let stat5 = StatValues { + min: Some(0), + max: Some(5), + ..Default::default() + }; + assert_eq!(stat1.overlaps(&stat5), StatOverlap::Zero); + assert_eq!(stat5.overlaps(&stat1), StatOverlap::Zero); + } + + #[test] + fn statistics_overlaps_none() { + let stat1 = StatValues { + min: Some(10), + max: Some(20), + ..Default::default() + }; + + let stat2 = StatValues { + min: None, + max: Some(20), + ..Default::default() + }; + assert_eq!(stat1.overlaps(&stat2), StatOverlap::Unknown); + assert_eq!(stat2.overlaps(&stat1), StatOverlap::Unknown); + + let stat3 = StatValues { + min: Some(10), + max: None, + ..Default::default() + }; + assert_eq!(stat1.overlaps(&stat3), StatOverlap::Unknown); + assert_eq!(stat3.overlaps(&stat1), StatOverlap::Unknown); + + let stat4 = StatValues { + min: None, + max: None, + ..Default::default() + }; + assert_eq!(stat1.overlaps(&stat4), StatOverlap::Unknown); + assert_eq!(stat4.overlaps(&stat1), StatOverlap::Unknown); + } + + #[test] + fn statistics_overlaps_mixed_none() { + let stat1 = StatValues { + min: Some(10), + max: None, + ..Default::default() + }; + + let stat2 = StatValues { + min: None, + max: Some(5), + ..Default::default() + }; + assert_eq!(stat1.overlaps(&stat2), StatOverlap::Unknown); + 
assert_eq!(stat2.overlaps(&stat1), StatOverlap::Unknown); + } + + #[test] + fn update_string() { + let mut stat = StatValues::new_with_value("bbb".to_string()); + assert_eq!(stat.min, Some("bbb".to_string())); + assert_eq!(stat.max, Some("bbb".to_string())); + assert_eq!(stat.total_count, 1); + + stat.update("aaa"); + assert_eq!(stat.min, Some("aaa".to_string())); + assert_eq!(stat.max, Some("bbb".to_string())); + assert_eq!(stat.total_count, 2); + + stat.update("z"); + assert_eq!(stat.min, Some("aaa".to_string())); + assert_eq!(stat.max, Some("z".to_string())); + assert_eq!(stat.total_count, 3); + + stat.update("p"); + assert_eq!(stat.min, Some("aaa".to_string())); + assert_eq!(stat.max, Some("z".to_string())); + assert_eq!(stat.total_count, 4); + } + + #[test] + fn stats_is_none() { + let stat = Statistics::I64(StatValues::new_non_null(Some(-1), Some(100), 1)); + assert!(!stat.is_none()); + + let stat = Statistics::I64(StatValues::new_non_null(None, Some(100), 1)); + assert!(!stat.is_none()); + + let stat = Statistics::I64(StatValues::new_non_null(None, None, 0)); + assert!(stat.is_none()); + } + + #[test] + fn stats_as_str_i64() { + let stat = Statistics::I64(StatValues::new_non_null(Some(-1), Some(100), 1)); + assert_eq!(stat.min_as_str(), Some("-1".into())); + assert_eq!(stat.max_as_str(), Some("100".into())); + + let stat = Statistics::I64(StatValues::new_non_null(None, None, 1)); + assert_eq!(stat.min_as_str(), None); + assert_eq!(stat.max_as_str(), None); + } + + #[test] + fn stats_as_str_u64() { + let stat = Statistics::U64(StatValues::new_non_null(Some(1), Some(100), 1)); + assert_eq!(stat.min_as_str(), Some("1".into())); + assert_eq!(stat.max_as_str(), Some("100".into())); + + let stat = Statistics::U64(StatValues::new_non_null(None, None, 1)); + assert_eq!(stat.min_as_str(), None); + assert_eq!(stat.max_as_str(), None); + } + + #[test] + fn stats_as_str_f64() { + let stat = Statistics::F64(StatValues::new_non_null(Some(99.0), Some(101.0), 1)); + 
assert_eq!(stat.min_as_str(), Some("99".into())); + assert_eq!(stat.max_as_str(), Some("101".into())); + + let stat = Statistics::F64(StatValues::new_non_null(None, None, 1)); + assert_eq!(stat.min_as_str(), None); + assert_eq!(stat.max_as_str(), None); + } + + #[test] + fn stats_as_str_bool() { + let stat = Statistics::Bool(StatValues::new_non_null(Some(false), Some(true), 1)); + assert_eq!(stat.min_as_str(), Some("false".into())); + assert_eq!(stat.max_as_str(), Some("true".into())); + + let stat = Statistics::Bool(StatValues::new_non_null(None, None, 1)); + assert_eq!(stat.min_as_str(), None); + assert_eq!(stat.max_as_str(), None); + } + + #[test] + fn stats_as_str_str() { + let stat = Statistics::String(StatValues::new_non_null( + Some("a".to_string()), + Some("zz".to_string()), + 1, + )); + assert_eq!(stat.min_as_str(), Some("a".into())); + assert_eq!(stat.max_as_str(), Some("zz".into())); + + let stat = Statistics::String(StatValues::new_non_null(None, None, 1)); + assert_eq!(stat.min_as_str(), None); + assert_eq!(stat.max_as_str(), None); + } + + #[test] + fn table_update_from() { + let mut string_stats = StatValues::new_with_value("foo".to_string()); + string_stats.update("bar"); + let string_col = ColumnSummary { + name: "string".to_string(), + influxdb_type: InfluxDbType::Field, + stats: Statistics::String(string_stats), + }; + + let mut int_stats = StatValues::new_with_value(1); + int_stats.update(&5); + let int_col = ColumnSummary { + name: "int".to_string(), + influxdb_type: InfluxDbType::Field, + stats: Statistics::I64(int_stats), + }; + + let mut float_stats = StatValues::new_with_value(9.1); + float_stats.update(&1.3); + let float_col = ColumnSummary { + name: "float".to_string(), + influxdb_type: InfluxDbType::Field, + stats: Statistics::F64(float_stats), + }; + + let mut table_a = TableSummary { + columns: vec![string_col, int_col, float_col], + }; + + let mut string_stats = StatValues::new_with_value("aaa".to_string()); + 
string_stats.update("zzz"); + let string_col = ColumnSummary { + name: "string".to_string(), + influxdb_type: InfluxDbType::Field, + stats: Statistics::String(string_stats), + }; + + let mut int_stats = StatValues::new_with_value(3); + int_stats.update(&9); + let int_col = ColumnSummary { + name: "int".to_string(), + influxdb_type: InfluxDbType::Field, + stats: Statistics::I64(int_stats), + }; + + let mut table_b = TableSummary { + columns: vec![int_col, string_col], + }; + + // keep this to test joining the other way + let table_c = table_a.clone(); + + table_a.update_from(&table_b); + let col = table_a.column("string").unwrap(); + assert_eq!( + col.stats, + Statistics::String(StatValues::new_non_null( + Some("aaa".to_string()), + Some("zzz".to_string()), + 4, + )) + ); + + let col = table_a.column("int").unwrap(); + assert_eq!( + col.stats, + Statistics::I64(StatValues::new_non_null(Some(1), Some(9), 4)) + ); + + let col = table_a.column("float").unwrap(); + assert_eq!( + col.stats, + Statistics::F64(StatValues::new(Some(1.3), Some(9.1), 4, Some(2))) + ); + + table_b.update_from(&table_c); + let col = table_b.column("string").unwrap(); + assert_eq!( + col.stats, + Statistics::String(StatValues::new_non_null( + Some("aaa".to_string()), + Some("zzz".to_string()), + 4, + )) + ); + + let col = table_b.column("int").unwrap(); + assert_eq!( + col.stats, + Statistics::I64(StatValues::new_non_null(Some(1), Some(9), 4)) + ); + + let col = table_b.column("float").unwrap(); + assert_eq!( + col.stats, + Statistics::F64(StatValues::new(Some(1.3), Some(9.1), 4, Some(2))) + ); + } + + #[test] + fn table_update_from_new_column() { + let string_stats = StatValues::new_with_value("bar".to_string()); + let string_col = ColumnSummary { + name: "string".to_string(), + influxdb_type: InfluxDbType::Tag, + stats: Statistics::String(string_stats), + }; + + let int_stats = StatValues::new_with_value(5); + let int_col = ColumnSummary { + name: "int".to_string(), + influxdb_type: 
InfluxDbType::Field, + stats: Statistics::I64(int_stats), + }; + + // table summary that does not have the "string" col + let table1 = TableSummary { + columns: vec![int_col.clone()], + }; + + // table summary that has both columns + let table2 = TableSummary { + columns: vec![int_col, string_col], + }; + + // Statistics should be the same regardless of the order we update the stats + + let expected_string_stats = Statistics::String(StatValues::new( + Some("bar".to_string()), + Some("bar".to_string()), + 2, // total count is 2 even though did not appear in the update + Some(1), // 1 null + )); + + let expected_int_stats = Statistics::I64(StatValues::new( + Some(5), + Some(5), + 2, + Some(0), // no nulls + )); + + // update table 1 with table 2 + let mut table = table1.clone(); + table.update_from(&table2); + + assert_eq!( + &table.column("string").unwrap().stats, + &expected_string_stats + ); + + assert_eq!(&table.column("int").unwrap().stats, &expected_int_stats); + + // update table 2 with table 1 + let mut table = table2; + table.update_from(&table1); + + assert_eq!( + &table.column("string").unwrap().stats, + &expected_string_stats + ); + + assert_eq!(&table.column("int").unwrap().stats, &expected_int_stats); + } + + #[test] + fn column_update_from_boolean() { + let bool_false = ColumnSummary { + name: "b".to_string(), + influxdb_type: InfluxDbType::Field, + stats: Statistics::Bool(StatValues::new(Some(false), Some(false), 1, Some(1))), + }; + let bool_true = ColumnSummary { + name: "b".to_string(), + influxdb_type: InfluxDbType::Field, + stats: Statistics::Bool(StatValues::new(Some(true), Some(true), 1, Some(2))), + }; + + let expected_stats = Statistics::Bool(StatValues::new(Some(false), Some(true), 2, Some(3))); + + let mut b = bool_false.clone(); + b.update_from(&bool_true); + assert_eq!(b.stats, expected_stats); + + let mut b = bool_true; + b.update_from(&bool_false); + assert_eq!(b.stats, expected_stats); + } + + #[test] + fn column_update_from_u64() { + 
let mut min = ColumnSummary { + name: "foo".to_string(), + influxdb_type: InfluxDbType::Field, + stats: Statistics::U64(StatValues::new(Some(5), Some(23), 1, Some(1))), + }; + + let max = ColumnSummary { + name: "foo".to_string(), + influxdb_type: InfluxDbType::Field, + stats: Statistics::U64(StatValues::new(Some(6), Some(506), 43, Some(2))), + }; + + min.update_from(&max); + + let expected = Statistics::U64(StatValues::new(Some(5), Some(506), 44, Some(3))); + assert_eq!(min.stats, expected); + } + + #[test] + fn nans() { + let mut stat = StatValues::default(); + assert_eq!(stat.min, None); + assert_eq!(stat.max, None); + assert_eq!(stat.total_count, 0); + + stat.update(&f64::NAN); + assert_eq!(stat.min, None); + assert_eq!(stat.max, None); + assert_eq!(stat.total_count, 1); + + stat.update(&1.0); + assert_eq!(stat.min, Some(1.0)); + assert_eq!(stat.max, Some(1.0)); + assert_eq!(stat.total_count, 2); + + stat.update(&2.0); + assert_eq!(stat.min, Some(1.0)); + assert_eq!(stat.max, Some(2.0)); + assert_eq!(stat.total_count, 3); + + stat.update(&f64::INFINITY); + assert_eq!(stat.min, Some(1.0)); + assert_eq!(stat.max, Some(f64::INFINITY)); + assert_eq!(stat.total_count, 4); + + stat.update(&-1.0); + assert_eq!(stat.min, Some(-1.0)); + assert_eq!(stat.max, Some(f64::INFINITY)); + assert_eq!(stat.total_count, 5); + + // =========== + + let mut stat = StatValues::new_with_value(2.0); + stat.update(&f64::INFINITY); + assert_eq!(stat.min, Some(2.0)); + assert_eq!(stat.max, Some(f64::INFINITY)); + assert_eq!(stat.total_count, 2); + + stat.update(&f64::NAN); + assert_eq!(stat.min, Some(2.0)); + assert_eq!(stat.max, Some(f64::INFINITY)); + assert_eq!(stat.total_count, 3); + + // =========== + + let mut stat2 = StatValues::new_with_value(1.0); + stat2.update_from(&stat); + assert_eq!(stat2.min, Some(1.0)); + assert_eq!(stat.max, Some(f64::INFINITY)); + assert_eq!(stat2.total_count, 4); + + // =========== + + let stat2 = StatValues::new_with_value(1.0); + 
stat.update_from(&stat2); + assert_eq!(stat.min, Some(1.0)); + assert_eq!(stat.max, Some(f64::INFINITY)); + assert_eq!(stat.total_count, 4); + + // =========== + + let stat = StatValues::new_with_value(f64::NAN); + assert_eq!(stat.min, None); + assert_eq!(stat.max, None); + assert_eq!(stat.total_count, 1); + } + + #[test] + fn test_timestamp_nano_min_max() { + let cases = vec![ + ( + "MIN / MAX Nanos", + TimestampRange::new(MIN_NANO_TIME, MAX_NANO_TIME + 1), + ), + ("MIN/MAX i64", TimestampRange::new(i64::MIN, i64::MAX)), + ]; + + for (name, range) in cases { + println!("case: {name}"); + assert!(!range.contains(i64::MIN)); + assert!(!range.contains(i64::MIN + 1)); + assert!(range.contains(MIN_NANO_TIME)); + assert!(range.contains(MIN_NANO_TIME + 1)); + assert!(range.contains(MAX_NANO_TIME - 1)); + assert!(range.contains(MAX_NANO_TIME)); + assert!(!range.contains(i64::MAX)); + assert!(range.contains_all()); + assert!(range.contains_nearly_all()); + } + } + + #[test] + fn test_timestamp_i64_min_max_offset() { + let range = TimestampRange::new(MIN_NANO_TIME + 1, MAX_NANO_TIME - 1); + + assert!(!range.contains(i64::MIN)); + assert!(!range.contains(MIN_NANO_TIME)); + assert!(range.contains(MIN_NANO_TIME + 1)); + assert!(range.contains(MAX_NANO_TIME - 2)); + assert!(!range.contains(MAX_NANO_TIME - 1)); + assert!(!range.contains(MAX_NANO_TIME)); + assert!(!range.contains(i64::MAX)); + assert!(!range.contains_all()); + assert!(!range.contains_nearly_all()); + } + + #[test] + fn test_timestamp_i64_min_max_offset_max() { + let range = TimestampRange::new(MIN_NANO_TIME, MAX_NANO_TIME); + + assert!(!range.contains(i64::MIN)); + assert!(range.contains(MIN_NANO_TIME)); + assert!(range.contains(MIN_NANO_TIME + 1)); + assert!(range.contains(MAX_NANO_TIME - 1)); + assert!(!range.contains(MAX_NANO_TIME)); + assert!(!range.contains(i64::MAX)); + assert!(!range.contains_all()); + assert!(range.contains_nearly_all()); + } + + #[test] + fn test_timestamp_range_contains() { + let range 
= TimestampRange::new(100, 200); + assert!(!range.contains(99)); + assert!(range.contains(100)); + assert!(range.contains(101)); + assert!(range.contains(199)); + assert!(!range.contains(200)); + assert!(!range.contains(201)); + } + + #[test] + fn test_timestamp_range_overlaps() { + let range = TimestampRange::new(100, 200); + assert!(!TimestampMinMax::new(0, 99).overlaps(range)); + assert!(TimestampMinMax::new(0, 100).overlaps(range)); + assert!(TimestampMinMax::new(0, 101).overlaps(range)); + + assert!(TimestampMinMax::new(0, 200).overlaps(range)); + assert!(TimestampMinMax::new(0, 201).overlaps(range)); + assert!(TimestampMinMax::new(0, 300).overlaps(range)); + + assert!(TimestampMinMax::new(100, 101).overlaps(range)); + assert!(TimestampMinMax::new(100, 200).overlaps(range)); + assert!(TimestampMinMax::new(100, 201).overlaps(range)); + + assert!(TimestampMinMax::new(101, 101).overlaps(range)); + assert!(TimestampMinMax::new(101, 200).overlaps(range)); + assert!(TimestampMinMax::new(101, 201).overlaps(range)); + + assert!(!TimestampMinMax::new(200, 200).overlaps(range)); + assert!(!TimestampMinMax::new(200, 201).overlaps(range)); + + assert!(!TimestampMinMax::new(201, 300).overlaps(range)); + } + + #[test] + #[should_panic(expected = "expected min (2) <= max (1)")] + fn test_timestamp_min_max_invalid() { + TimestampMinMax::new(2, 1); + } + + #[test] + fn test_table_schema_size() { + let schema1 = TableSchema { + id: TableId::new(1), + partition_template: Default::default(), + columns: ColumnsByName::default(), + }; + let schema2 = TableSchema { + id: TableId::new(2), + partition_template: Default::default(), + columns: ColumnsByName::new([Column { + id: ColumnId::new(1), + table_id: TableId::new(2), + name: String::from("foo"), + column_type: ColumnType::Bool, + }]), + }; + assert!(schema1.size() < schema2.size()); + } + + #[test] + fn test_namespace_schema_size() { + let schema1 = NamespaceSchema { + id: NamespaceId::new(1), + tables: BTreeMap::from([]), + 
max_tables: MaxTables::try_from(42).unwrap(), + max_columns_per_table: MaxColumnsPerTable::try_from(4).unwrap(), + retention_period_ns: None, + partition_template: Default::default(), + }; + let schema2 = NamespaceSchema { + id: NamespaceId::new(1), + tables: BTreeMap::from([( + String::from("foo"), + TableSchema { + id: TableId::new(1), + columns: ColumnsByName::default(), + partition_template: Default::default(), + }, + )]), + max_tables: MaxTables::try_from(42).unwrap(), + max_columns_per_table: MaxColumnsPerTable::try_from(4).unwrap(), + retention_period_ns: None, + partition_template: Default::default(), + }; + assert!(schema1.size() < schema2.size()); + } + + #[test] + #[should_panic = "timestamp wraparound"] + fn test_timestamp_wraparound_panic_add_i64() { + let _ = Timestamp::new(i64::MAX) + 1; + } + + #[test] + #[should_panic = "timestamp wraparound"] + fn test_timestamp_wraparound_panic_sub_i64() { + let _ = Timestamp::new(i64::MIN) - 1; + } + + #[test] + #[should_panic = "timestamp wraparound"] + fn test_timestamp_wraparound_panic_add_timestamp() { + let _ = Timestamp::new(i64::MAX) + Timestamp::new(1); + } + + #[test] + #[should_panic = "timestamp wraparound"] + fn test_timestamp_wraparound_panic_sub_timestamp() { + let _ = Timestamp::new(i64::MIN) - Timestamp::new(1); + } + + #[test] + fn test_timestamprange_start_after_end() { + let tr = TimestampRange::new(2, 1); + assert_eq!(tr.start(), 1); + assert_eq!(tr.end(), 1); + } +} diff --git a/data_types/src/namespace_name.rs b/data_types/src/namespace_name.rs new file mode 100644 index 0000000..e9e2e58 --- /dev/null +++ b/data_types/src/namespace_name.rs @@ -0,0 +1,350 @@ +use std::{borrow::Cow, ops::RangeInclusive}; + +use thiserror::Error; + +/// Length constraints for a [`NamespaceName`] name. +/// +/// A `RangeInclusive` is a closed interval, covering [1, 64] +const LENGTH_CONSTRAINT: RangeInclusive = 1..=64; + +/// Allowlist of chars for a [`NamespaceName`] name. 
+/// +/// '/' | '_' | '-' are utilized by the platforms. +fn is_allowed(c: char) -> bool { + c.is_alphanumeric() || matches!(c, '/' | '_' | '-') +} + +/// Errors returned when attempting to construct a [`NamespaceName`] from an org +/// & bucket string pair. +#[derive(Debug, Error)] +pub enum OrgBucketMappingError { + /// An error returned when the org, or bucket string contains invalid + /// characters. + #[error("invalid namespace name: {0}")] + InvalidNamespaceName(#[from] NamespaceNameError), + + /// Either the org, or bucket is an empty string. + #[error("missing org/bucket value")] + NoOrgBucketSpecified, +} + +/// [`NamespaceName`] name validation errors. +#[derive(Debug, Error)] +pub enum NamespaceNameError { + /// The provided namespace name does not fall within the valid length of a + /// namespace. + #[error( + "namespace name {} length must be between {} and {} characters", + name, + LENGTH_CONSTRAINT.start(), + LENGTH_CONSTRAINT.end() + )] + LengthConstraint { + /// The user-provided namespace that failed validation. + name: String, + }, + + /// The provided namespace name contains an unacceptable character. + #[error( + "namespace name '{}' contains invalid character, character number {} \ + is not whitelisted", + name, + bad_char_offset + )] + BadChars { + /// The zero-indexed (multi-byte) character position that failed + /// validation. + bad_char_offset: usize, + /// The user-provided namespace that failed validation. + name: String, + }, +} + +/// A correctly formed namespace name. +/// +/// Using this wrapper type allows the consuming code to enforce the invariant +/// that only valid names are provided. 
+/// +/// This type derefs to a `str` and therefore can be used in place of anything +/// that is expecting a `str`: +/// +/// ```rust +/// # use data_types::NamespaceName; +/// fn print_namespace(s: &str) { +/// println!("namespace name: {}", s); +/// } +/// +/// let ns = NamespaceName::new("data").unwrap(); +/// print_namespace(&ns); +/// ``` +/// +/// But this is not reciprocal - functions that wish to accept only +/// pre-validated names can use `NamespaceName` as a parameter. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct NamespaceName<'a>(Cow<'a, str>); + +impl<'a> NamespaceName<'a> { + /// Create a new, valid NamespaceName. + pub fn new>>(name: T) -> Result { + let name: Cow<'a, str> = name.into(); + + if !LENGTH_CONSTRAINT.contains(&name.len()) { + return Err(NamespaceNameError::LengthConstraint { + name: name.to_string(), + }); + } + + // Validate the name contains only valid characters. + // + // NOTE: If changing these characters, please update the error message + // above. + if let Some(bad_char_offset) = name.chars().position(|c| !is_allowed(c)) { + return Err(NamespaceNameError::BadChars { + bad_char_offset, + name: name.to_string(), + }); + }; + + Ok(Self(name)) + } + + /// Borrow a string slice of the name. + pub fn as_str(&self) -> &str { + self.0.as_ref() + } + + /// Map an InfluxDB 2.X org & bucket into an IOx NamespaceName. + /// + /// This function ensures the mapping is unambiguous by encoding any + /// non-alphanumeric characters in both `org` and `bucket` in addition to + /// the validation performed in [`NamespaceName::new()`]. + pub fn from_org_and_bucket, B: AsRef>( + org: O, + bucket: B, + ) -> Result { + let org = org.as_ref(); + let bucket = bucket.as_ref(); + + if org.is_empty() || bucket.is_empty() { + return Err(OrgBucketMappingError::NoOrgBucketSpecified); + } + + Ok(Self::new(format!("{}_{}", org, bucket))?) + } + + /// Efficiently returns the string representation of this [`NamespaceName`]. 
+ /// + /// If this [`NamespaceName`] contains an owned string, it is returned + /// without cloning. + pub fn into_string(self) -> String { + self.0.into_owned() + } +} + +impl<'a> std::convert::From> for String { + fn from(name: NamespaceName<'a>) -> Self { + name.0.to_string() + } +} + +impl<'a> std::convert::From<&NamespaceName<'a>> for String { + fn from(name: &NamespaceName<'a>) -> Self { + name.to_string() + } +} + +impl<'a> std::convert::TryFrom<&'a str> for NamespaceName<'a> { + type Error = NamespaceNameError; + + fn try_from(v: &'a str) -> Result { + Self::new(v) + } +} + +impl<'a> std::convert::TryFrom for NamespaceName<'a> { + type Error = NamespaceNameError; + + fn try_from(v: String) -> Result { + Self::new(v) + } +} + +impl<'a> std::ops::Deref for NamespaceName<'a> { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + +impl<'a> AsRef<[u8]> for NamespaceName<'a> { + fn as_ref(&self) -> &[u8] { + self.as_str().as_bytes() + } +} + +impl<'a> std::fmt::Display for NamespaceName<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_org_bucket_map_db_ok() { + let got = NamespaceName::from_org_and_bucket("org", "bucket") + .expect("failed on valid DB mapping"); + + assert_eq!(got.as_str(), "org_bucket"); + assert_eq!(got.into_string(), "org_bucket"); + } + + #[test] + fn test_into_string() { + // Ref type str + assert_eq!( + NamespaceName::new("bananas").unwrap().into_string(), + "bananas" + ); + // Owned type string + assert_eq!( + NamespaceName::new("bananas".to_string()) + .unwrap() + .into_string(), + "bananas" + ); + } + + #[test] + fn test_org_bucket_map_db_contains_underscore() { + let got = NamespaceName::from_org_and_bucket("my_org", "bucket").unwrap(); + assert_eq!(got.as_str(), "my_org_bucket"); + + let got = NamespaceName::from_org_and_bucket("org", "my_bucket").unwrap(); + assert_eq!(got.as_str(), 
"org_my_bucket"); + + let got = NamespaceName::from_org_and_bucket("org", "my__bucket").unwrap(); + assert_eq!(got.as_str(), "org_my__bucket"); + + let got = NamespaceName::from_org_and_bucket("my_org", "my_bucket").unwrap(); + assert_eq!(got.as_str(), "my_org_my_bucket"); + } + + #[test] + fn test_org_bucket_map_db_contains_underscore_and_percent() { + let err = NamespaceName::from_org_and_bucket("my%5Forg", "bucket"); + assert!(matches!( + err, + Err(OrgBucketMappingError::InvalidNamespaceName { .. }) + )); + + let err = NamespaceName::from_org_and_bucket("my%5Forg_", "bucket"); + assert!(matches!( + err, + Err(OrgBucketMappingError::InvalidNamespaceName { .. }) + )); + } + + #[test] + fn test_bad_namespace_name_fails_validation() { + let err = NamespaceName::from_org_and_bucket("org", "bucket?"); + assert!(matches!( + err, + Err(OrgBucketMappingError::InvalidNamespaceName { .. }) + )); + + let err = NamespaceName::from_org_and_bucket("org!", "bucket"); + assert!(matches!( + err, + Err(OrgBucketMappingError::InvalidNamespaceName { .. 
}) + )); + } + + #[test] + fn test_empty_org_bucket() { + let err = NamespaceName::from_org_and_bucket("", "") + .expect_err("should fail with empty org/bucket valuese"); + assert!(matches!(err, OrgBucketMappingError::NoOrgBucketSpecified)); + } + + #[test] + fn test_deref() { + let db = NamespaceName::new("my_example_name").unwrap(); + assert_eq!(&*db, "my_example_name"); + } + + #[test] + fn test_too_short() { + let name = "".to_string(); + let got = NamespaceName::try_from(name).unwrap_err(); + + assert!(matches!( + got, + NamespaceNameError::LengthConstraint { name: _n } + )); + } + + #[test] + fn test_too_long() { + let name = "my_example_name_that_is_quite_a_bit_longer_than_allowed_even_though_database_names_can_be_quite_long_bananas".to_string(); + let got = NamespaceName::try_from(name).unwrap_err(); + + assert!(matches!( + got, + NamespaceNameError::LengthConstraint { name: _n } + )); + } + + #[test] + fn test_bad_chars_null() { + let got = NamespaceName::new("example\x00").unwrap_err(); + assert_eq!(got.to_string() , "namespace name 'example\x00' contains invalid character, character number 7 is not whitelisted"); + } + + #[test] + fn test_bad_chars_high_control() { + let got = NamespaceName::new("\u{007f}example").unwrap_err(); + assert_eq!(got.to_string() , "namespace name '\u{007f}example' contains invalid character, character number 0 is not whitelisted"); + } + + #[test] + fn test_bad_chars_tab() { + let got = NamespaceName::new("example\tdb").unwrap_err(); + assert_eq!(got.to_string() , "namespace name 'example\tdb' contains invalid character, character number 7 is not whitelisted"); + } + + #[test] + fn test_bad_chars_newline() { + let got = NamespaceName::new("my_example\ndb").unwrap_err(); + assert_eq!(got.to_string() , "namespace name 'my_example\ndb' contains invalid character, character number 10 is not whitelisted"); + } + + #[test] + fn test_bad_chars_whitespace() { + let got = NamespaceName::new("my_example db").unwrap_err(); + 
assert_eq!(got.to_string() , "namespace name 'my_example db' contains invalid character, character number 10 is not whitelisted"); + } + + #[test] + fn test_bad_chars_single_quote() { + let got = NamespaceName::new("my_example'db").unwrap_err(); + assert_eq!(got.to_string() , "namespace name 'my_example\'db' contains invalid character, character number 10 is not whitelisted"); + } + + #[test] + fn test_ok_chars() { + let db = + NamespaceName::new("my-example-db_with_underscores/and/fwd/slash/AndCaseSensitive") + .unwrap(); + assert_eq!( + &*db, + "my-example-db_with_underscores/and/fwd/slash/AndCaseSensitive" + ); + + let db = NamespaceName::new("a_ã_京").unwrap(); + assert_eq!(&*db, "a_ã_京"); + } +} diff --git a/data_types/src/partition.rs b/data_types/src/partition.rs new file mode 100644 index 0000000..eb09524 --- /dev/null +++ b/data_types/src/partition.rs @@ -0,0 +1,690 @@ +//! Types having to do with partitions. + +use super::{ColumnsByName, SortKeyIds, TableId, Timestamp}; + +use schema::sort::SortKey; +use sha2::Digest; +use std::{fmt::Display, sync::Arc}; +use thiserror::Error; + +/// Unique ID for a `Partition` during the transition from catalog-assigned sequential +/// `PartitionId`s to deterministic `PartitionHashId`s. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum TransitionPartitionId { + /// The old catalog-assigned sequential `PartitionId`s that are in the process of being + /// deprecated. + Deprecated(PartitionId), + /// The new deterministic, hash-based `PartitionHashId`s that will be the new way to identify + /// partitions. + Deterministic(PartitionHashId), +} + +impl TransitionPartitionId { + /// Create a [`TransitionPartitionId`] from a [`PartitionId`] and optional [`PartitionHashId`] + pub fn from_parts(id: PartitionId, hash_id: Option) -> Self { + match hash_id { + Some(x) => Self::Deterministic(x), + None => Self::Deprecated(id), + } + } + + /// Size in bytes including `self`. 
+ pub fn size(&self) -> usize { + match self { + Self::Deprecated(_) => std::mem::size_of::(), + Self::Deterministic(id) => { + std::mem::size_of::() + id.size() - std::mem::size_of_val(id) + } + } + } +} + +impl<'a, R> sqlx::FromRow<'a, R> for TransitionPartitionId +where + R: sqlx::Row, + &'static str: sqlx::ColumnIndex, + PartitionId: sqlx::decode::Decode<'a, R::Database>, + PartitionId: sqlx::types::Type, + Option: sqlx::decode::Decode<'a, R::Database>, + Option: sqlx::types::Type, +{ + fn from_row(row: &'a R) -> sqlx::Result { + let partition_id: Option = row.try_get("partition_id")?; + let partition_hash_id: Option = row.try_get("partition_hash_id")?; + + let transition_partition_id = match (partition_id, partition_hash_id) { + (_, Some(hash_id)) => TransitionPartitionId::Deterministic(hash_id), + (Some(id), _) => TransitionPartitionId::Deprecated(id), + (None, None) => { + return Err(sqlx::Error::ColumnDecode { + index: "partition_id".into(), + source: "Both partition_id and partition_hash_id were NULL".into(), + }) + } + }; + + Ok(transition_partition_id) + } +} + +impl From<(PartitionId, Option<&PartitionHashId>)> for TransitionPartitionId { + fn from((partition_id, partition_hash_id): (PartitionId, Option<&PartitionHashId>)) -> Self { + Self::from_parts(partition_id, partition_hash_id.cloned()) + } +} + +impl std::fmt::Display for TransitionPartitionId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Deprecated(old_partition_id) => write!(f, "{}", old_partition_id.0), + Self::Deterministic(partition_hash_id) => write!(f, "{}", partition_hash_id), + } + } +} + +impl TransitionPartitionId { + /// Create a new `TransitionPartitionId::Deterministic` with the given table + /// ID and partition key. Provided to reduce typing and duplication a bit, + /// and because this variant should be most common now. 
+ /// + /// This MUST NOT be used for partitions that are addressed using legacy / + /// deprecated catalog row IDs, which should use + /// [`TransitionPartitionId::Deprecated`] instead. + pub fn new(table_id: TableId, partition_key: &PartitionKey) -> Self { + Self::Deterministic(PartitionHashId::new(table_id, partition_key)) + } + + /// Create a new `TransitionPartitionId` for cases in tests where you need some value but the + /// value doesn't matter. Public and not test-only so that other crates' tests can use this. + pub fn arbitrary_for_testing() -> Self { + Self::new(TableId::new(0), &PartitionKey::from("arbitrary")) + } +} + +/// Errors deserialising protobuf representations of [`TransitionPartitionId`]. +#[derive(Debug, Error)] +pub enum PartitionIdProtoError { + /// The proto type does not contain an ID. + #[error("no id specified for partition id")] + NoId, + + /// The specified hash ID is invalid. + #[error(transparent)] + InvalidHashId(#[from] PartitionHashIdError), +} + +/// Serialise a [`TransitionPartitionId`] to a protobuf representation. +impl From + for generated_types::influxdata::iox::catalog::v1::PartitionIdentifier +{ + fn from(value: TransitionPartitionId) -> Self { + use generated_types::influxdata::iox::catalog::v1 as proto; + match value { + TransitionPartitionId::Deprecated(id) => proto::PartitionIdentifier { + id: Some(proto::partition_identifier::Id::CatalogId(id.get())), + }, + TransitionPartitionId::Deterministic(hash) => proto::PartitionIdentifier { + id: Some(proto::partition_identifier::Id::HashId( + hash.as_bytes().to_owned(), + )), + }, + } + } +} + +/// Deserialise a [`TransitionPartitionId`] from a protobuf representation. 
+impl TryFrom + for TransitionPartitionId +{ + type Error = PartitionIdProtoError; + + fn try_from( + value: generated_types::influxdata::iox::catalog::v1::PartitionIdentifier, + ) -> Result { + use generated_types::influxdata::iox::catalog::v1 as proto; + + let id = value.id.ok_or(PartitionIdProtoError::NoId)?; + + Ok(match id { + proto::partition_identifier::Id::CatalogId(v) => { + TransitionPartitionId::Deprecated(PartitionId::new(v)) + } + proto::partition_identifier::Id::HashId(hash) => { + TransitionPartitionId::Deterministic(PartitionHashId::try_from(hash.as_slice())?) + } + }) + } +} + +/// Unique ID for a `Partition` +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, sqlx::FromRow)] +#[sqlx(transparent)] +pub struct PartitionId(i64); + +#[allow(missing_docs)] +impl PartitionId { + pub const fn new(v: i64) -> Self { + Self(v) + } + pub fn get(&self) -> i64 { + self.0 + } +} + +impl std::fmt::Display for PartitionId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Defines a partition via an arbitrary string within a table within +/// a namespace. +/// +/// Implemented as a reference-counted string, serialisable to +/// the Postgres VARCHAR data type. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct PartitionKey(Arc); + +impl PartitionKey { + /// Returns true if this instance of [`PartitionKey`] is backed by the same + /// string storage as other. + pub fn ptr_eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } + + /// Returns underlying string. + pub fn inner(&self) -> &str { + &self.0 + } + + /// Returns the bytes of the inner string. 
+ pub fn as_bytes(&self) -> &[u8] { + self.0.as_bytes() + } +} + +impl Display for PartitionKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +impl From for PartitionKey { + fn from(s: String) -> Self { + assert!(!s.is_empty()); + Self(s.into()) + } +} + +impl From<&str> for PartitionKey { + fn from(s: &str) -> Self { + assert!(!s.is_empty()); + Self(s.into()) + } +} + +impl sqlx::Type for PartitionKey { + fn type_info() -> sqlx::postgres::PgTypeInfo { + // Store this type as VARCHAR + sqlx::postgres::PgTypeInfo::with_name("VARCHAR") + } +} + +impl sqlx::Encode<'_, sqlx::Postgres> for PartitionKey { + fn encode_by_ref( + &self, + buf: &mut >::ArgumentBuffer, + ) -> sqlx::encode::IsNull { + <&str as sqlx::Encode>::encode(&self.0, buf) + } +} + +impl sqlx::Decode<'_, sqlx::Postgres> for PartitionKey { + fn decode( + value: >::ValueRef, + ) -> Result> { + Ok(Self( + >::decode(value)?.into(), + )) + } +} + +impl sqlx::Type for PartitionKey { + fn type_info() -> sqlx::sqlite::SqliteTypeInfo { + >::type_info() + } +} + +impl sqlx::Encode<'_, sqlx::Sqlite> for PartitionKey { + fn encode_by_ref( + &self, + buf: &mut >::ArgumentBuffer, + ) -> sqlx::encode::IsNull { + >::encode(self.0.to_string(), buf) + } +} + +impl sqlx::Decode<'_, sqlx::Sqlite> for PartitionKey { + fn decode( + value: >::ValueRef, + ) -> Result> { + Ok(Self( + >::decode(value)?.into(), + )) + } +} + +const PARTITION_HASH_ID_SIZE_BYTES: usize = 32; + +/// Uniquely identify a partition based on its table ID and partition key. 
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, sqlx::FromRow)] +#[sqlx(transparent)] +pub struct PartitionHashId(Arc<[u8; PARTITION_HASH_ID_SIZE_BYTES]>); + +impl std::fmt::Display for PartitionHashId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for byte in &*self.0 { + write!(f, "{:02x}", byte)?; + } + Ok(()) + } +} + +impl std::hash::Hash for PartitionHashId { + #[inline(always)] + fn hash(&self, state: &mut H) { + // the slice is already hashed, so we can be a bit more efficient: + // A hash of an object is technically only 64bits (this is what `Hasher::finish()` will produce). We assume that + // the SHA256 hash sum that was used to create the partition hash is good enough so that every 64-bit slice of + // it is a good hash candidate for the entire object. Hence, we only forward the first 64 bits to the hasher and + // call it a day. + + // There is currently no nice way to slice fixed-sized arrays, see: + // https://github.com/rust-lang/rust/issues/90091 + // + // So we implement this the hard way (to avoid some nasty panic paths that are quite expensive within a hash function). + // Conversion borrowed from https://github.com/rust-lang/rfcs/issues/1833#issuecomment-269509262 + const N_BYTES: usize = u64::BITS as usize / 8; + #[allow(clippy::assertions_on_constants)] + const _: () = assert!(PARTITION_HASH_ID_SIZE_BYTES >= N_BYTES); + let ptr = self.0.as_ptr() as *const [u8; N_BYTES]; + let sub: &[u8; N_BYTES] = unsafe { &*ptr }; + + state.write_u64(u64::from_ne_bytes(*sub)); + } +} + +/// Reasons bytes specified aren't a valid `PartitionHashId`. 
+#[derive(Debug, Error)] +#[allow(missing_copy_implementations)] +pub enum PartitionHashIdError { + /// The bytes specified were not valid + #[error("Could not interpret bytes as `PartitionHashId`: {data:?}")] + InvalidBytes { + /// The bytes used in the attempt to create a `PartitionHashId` + data: Vec, + }, +} + +impl TryFrom<&[u8]> for PartitionHashId { + type Error = PartitionHashIdError; + + fn try_from(data: &[u8]) -> Result { + let data: [u8; PARTITION_HASH_ID_SIZE_BYTES] = + data.try_into() + .map_err(|_| PartitionHashIdError::InvalidBytes { + data: data.to_vec(), + })?; + + Ok(Self(Arc::new(data))) + } +} + +impl PartitionHashId { + /// Create a new `PartitionHashId`. + pub fn new(table_id: TableId, partition_key: &PartitionKey) -> Self { + Self::from_raw(table_id, partition_key.as_bytes()) + } + + /// Create a new `PartitionHashId` + pub fn from_raw(table_id: TableId, key: &[u8]) -> Self { + // The hash ID of a partition is the SHA-256 of the `TableId` then the `PartitionKey`. This + // particular hash format was chosen so that there won't be collisions and this value can + // be used to uniquely identify a Partition without needing to go to the catalog to get a + // database-assigned ID. Given that users might set their `PartitionKey`, a cryptographic + // hash scoped by the `TableId` is needed to prevent malicious users from constructing + // collisions. This data will be held in memory across many services, so SHA-256 was chosen + // over SHA-512 to get the needed attributes in the smallest amount of space. + let mut inner = sha2::Sha256::new(); + + let table_bytes = table_id.to_be_bytes(); + // Avoiding collisions depends on the table ID's bytes always being a fixed size. So even + // though the current return type of `TableId::to_be_bytes` is `[u8; 8]`, we're asserting + // on the length here to make sure this code's assumptions hold even if the type of + // `TableId` changes in the future. 
+ assert_eq!(table_bytes.len(), 8); + inner.update(table_bytes); + + inner.update(key); + Self(Arc::new(inner.finalize().into())) + } + + /// Read access to the bytes of the hash identifier. + pub fn as_bytes(&self) -> &[u8] { + self.0.as_ref() + } + + /// Size in bytes including `Self`. + pub fn size(&self) -> usize { + std::mem::size_of::() + self.0.len() + } + + /// Create a new `PartitionHashId` for cases in tests where you need some value but the value + /// doesn't matter. Public and not test-only so that other crates' tests can use this. + pub fn arbitrary_for_testing() -> Self { + Self::new(TableId::new(0), &PartitionKey::from("arbitrary")) + } +} + +impl<'q> sqlx::encode::Encode<'q, sqlx::Postgres> for &'q PartitionHashId { + fn encode_by_ref(&self, buf: &mut sqlx::postgres::PgArgumentBuffer) -> sqlx::encode::IsNull { + buf.extend_from_slice(self.0.as_ref()); + + sqlx::encode::IsNull::No + } +} + +impl<'q> sqlx::encode::Encode<'q, sqlx::Sqlite> for &'q PartitionHashId { + fn encode_by_ref( + &self, + args: &mut Vec>, + ) -> sqlx::encode::IsNull { + args.push(sqlx::sqlite::SqliteArgumentValue::Blob( + std::borrow::Cow::Borrowed(self.0.as_ref()), + )); + + sqlx::encode::IsNull::No + } +} + +impl<'r, DB: ::sqlx::Database> ::sqlx::decode::Decode<'r, DB> for PartitionHashId +where + &'r [u8]: sqlx::Decode<'r, DB>, +{ + fn decode( + value: >::ValueRef, + ) -> ::std::result::Result< + Self, + ::std::boxed::Box< + dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync, + >, + > { + let data = <&[u8] as ::sqlx::decode::Decode<'r, DB>>::decode(value)?; + let data: [u8; PARTITION_HASH_ID_SIZE_BYTES] = data.try_into()?; + Ok(Self(Arc::new(data))) + } +} + +impl<'r, DB: ::sqlx::Database> ::sqlx::Type for PartitionHashId +where + &'r [u8]: ::sqlx::Type, +{ + fn type_info() -> DB::TypeInfo { + <&[u8] as ::sqlx::Type>::type_info() + } +} + +/// Data object for a partition. The combination of table and key are unique (i.e. 
only one record +/// can exist for each combo) +#[derive(Debug, Clone, PartialEq, Eq, sqlx::FromRow, Hash)] +pub struct Partition { + /// the id of the partition + pub id: PartitionId, + /// The unique hash derived from the table ID and partition key, if available. This will become + /// required when partitions without the value have aged out. + hash_id: Option, + /// the table the partition is under + pub table_id: TableId, + /// the string key of the partition + pub partition_key: PartitionKey, + + /// Vector of column IDs that describes how *every* parquet file + /// in this [`Partition`] is sorted. The sort key contains all the + /// primary key (PK) columns that have been persisted, and nothing + /// else. The PK columns are all `tag` columns and the `time` + /// column. + /// + /// Even though it is possible for both the unpersisted data + /// and/or multiple parquet files to contain different subsets of + /// columns, the partition's sort key is guaranteed to be + /// "compatible" across all files. Compatible means that the + /// parquet file is sorted in the same order as the partition + /// sort key after removing any missing columns. + /// + /// Partitions are initially created before any data is persisted + /// with an empty sort key. The partition sort key is updated as + /// needed when data is persisted to parquet files: both on the + /// first persist when the sort key is empty, as on subsequent + /// persist operations when new tags occur in newly inserted data. + /// + /// Updating inserts new columns into the existing sort key. The order + /// of the existing columns relative to each other is NOT changed. + /// + /// For example, updating `A,B,C` to either `A,D,B,C` or `A,B,C,D` + /// is legal. However, updating to `A,C,D,B` is not because the + /// relative order of B and C has been reversed. 
+ sort_key_ids: SortKeyIds, + + /// The time at which the newest file of the partition is created + pub new_file_at: Option, +} + +impl Partition { + /// Create a new Partition data object from the given attributes. This constructor will take + /// care of computing the [`PartitionHashId`]. + /// + /// This is only appropriate to use in the catalog or in tests. + pub fn new_catalog_only( + id: PartitionId, + hash_id: Option, + table_id: TableId, + partition_key: PartitionKey, + sort_key_ids: SortKeyIds, + new_file_at: Option, + ) -> Self { + Self { + id, + hash_id, + table_id, + partition_key, + sort_key_ids, + new_file_at, + } + } + + /// If this partition has a `PartitionHashId` stored in the catalog, use that. Otherwise, use + /// the database-assigned `PartitionId`. + pub fn transition_partition_id(&self) -> TransitionPartitionId { + TransitionPartitionId::from((self.id, self.hash_id.as_ref())) + } + + /// The unique hash derived from the table ID and partition key, if it exists in the catalog. + pub fn hash_id(&self) -> Option<&PartitionHashId> { + self.hash_id.as_ref() + } + + /// The sort key IDs, if the sort key has been set + pub fn sort_key_ids(&self) -> Option<&SortKeyIds> { + if self.sort_key_ids.is_empty() { + None + } else { + Some(&self.sort_key_ids) + } + } + + /// The sort key containing the column names found in the specified column map. + /// + /// # Panics + /// + /// Will panic if an ID isn't found in the column map. + pub fn sort_key(&self, columns_by_name: &ColumnsByName) -> Option { + self.sort_key_ids() + .map(|sort_key_ids| sort_key_ids.to_sort_key(columns_by_name)) + } + + /// Change the sort key IDs to the given sort key IDs. This should only be used in the + /// in-memory catalog or in tests; all other sort key updates should go through the catalog + /// functions. 
+ pub fn set_sort_key_ids(&mut self, sort_key_ids: &SortKeyIds) { + self.sort_key_ids = sort_key_ids.clone(); + } +} + +#[cfg(test)] +pub(crate) mod tests { + use std::hash::{Hash, Hasher}; + + use super::*; + + use assert_matches::assert_matches; + use proptest::{prelude::*, proptest}; + + /// A fixture test asserting the deterministic partition ID generation + /// algorithm outputs a fixed value, preventing accidental changes to the + /// derived ID. + /// + /// This hash byte value MUST NOT change for the lifetime of a cluster + /// (though the encoding used in this test can). + #[test] + fn display_partition_hash_id_in_hex() { + let partition_hash_id = + PartitionHashId::new(TableId::new(5), &PartitionKey::from("2023-06-08")); + + assert_eq!( + "ebd1041daa7c644c99967b817ae607bdcb754c663f2c415f270d6df720280f7a", + partition_hash_id.to_string() + ); + } + + prop_compose! { + /// Return an arbitrary [`TransitionPartitionId`] with a randomised ID + /// value. + pub fn arbitrary_partition_id()( + use_hash in any::(), + row_id in any::(), + hash_id in any::<[u8; PARTITION_HASH_ID_SIZE_BYTES]>() + ) -> TransitionPartitionId { + match use_hash { + true => TransitionPartitionId::Deterministic(PartitionHashId(hash_id.into())), + false => TransitionPartitionId::Deprecated(PartitionId::new(row_id)), + } + } + } + + proptest! { + #[test] + fn partition_hash_id_representations( + table_id in 0..i64::MAX, + partition_key in ".+", + ) { + let table_id = TableId::new(table_id); + let partition_key = PartitionKey::from(partition_key); + + let partition_hash_id = PartitionHashId::new(table_id, &partition_key); + + // ID generation MUST be deterministic. 
+ let partition_hash_id_regenerated = PartitionHashId::new(table_id, &partition_key); + assert_eq!(partition_hash_id, partition_hash_id_regenerated); + + // ID generation MUST be collision resistant; different inputs -> different IDs + let other_table_id = TableId::new(table_id.get().wrapping_add(1)); + let different_partition_hash_id = PartitionHashId::new(other_table_id, &partition_key); + assert_ne!(partition_hash_id, different_partition_hash_id); + + // The bytes of the partition hash ID are stored in the catalog and sent from the + // ingesters to the queriers. We should be able to round-trip through bytes. + let bytes_representation = partition_hash_id.as_bytes(); + assert_eq!(bytes_representation.len(), 32); + let from_bytes = PartitionHashId::try_from(bytes_representation).unwrap(); + assert_eq!(from_bytes, partition_hash_id); + + // The hex string of the bytes is used in the Parquet file path in object storage, and + // should always be the same length. + let string_representation = partition_hash_id.to_string(); + assert_eq!(string_representation.len(), 64); + + // While nothing is currently deserializing the hex string to create `PartitionHashId` + // instances, it should work because there's nothing preventing it either. + let bytes_from_string = hex::decode(string_representation).unwrap(); + let from_string = PartitionHashId::try_from(&bytes_from_string[..]).unwrap(); + assert_eq!(from_string, partition_hash_id); + } + + /// Assert a [`TransitionPartitionId`] is round-trippable through proto + /// serialisation. + #[test] + fn prop_partition_id_proto_round_trip(id in arbitrary_partition_id()) { + use generated_types::influxdata::iox::catalog::v1 as proto; + + // Encoding is infallible + let encoded = proto::PartitionIdentifier::from(id.clone()); + + // Decoding a valid ID is infallible. 
+ let decoded = TransitionPartitionId::try_from(encoded).unwrap(); + + // The deserialised value must match the input (round trippable) + assert_eq!(decoded, id); + } + } + + #[test] + fn test_proto_no_id() { + use generated_types::influxdata::iox::catalog::v1 as proto; + + let msg = proto::PartitionIdentifier { id: None }; + + assert_matches!( + TransitionPartitionId::try_from(msg), + Err(PartitionIdProtoError::NoId) + ); + } + + #[test] + fn test_proto_bad_hash() { + use generated_types::influxdata::iox::catalog::v1 as proto; + + let msg = proto::PartitionIdentifier { + id: Some(proto::partition_identifier::Id::HashId(vec![42])), + }; + + assert_matches!( + TransitionPartitionId::try_from(msg), + Err(PartitionIdProtoError::InvalidHashId(_)) + ); + } + + #[test] + fn test_hash_partition_hash_id() { + let id = PartitionHashId::arbitrary_for_testing(); + + let mut hasher = TestHasher::default(); + id.hash(&mut hasher); + + assert_eq!(hasher.written, vec![id.as_bytes()[..8].to_vec()],); + } + + #[derive(Debug, Default)] + struct TestHasher { + written: Vec>, + } + + impl Hasher for TestHasher { + fn finish(&self) -> u64 { + unimplemented!() + } + + fn write(&mut self, bytes: &[u8]) { + self.written.push(bytes.to_vec()); + } + } +} diff --git a/data_types/src/partition_template.rs b/data_types/src/partition_template.rs new file mode 100644 index 0000000..bbd0633 --- /dev/null +++ b/data_types/src/partition_template.rs @@ -0,0 +1,1949 @@ +//! Partition templating with per-namespace & table override types. +//! +//! The override types utilise per-entity wrappers for type safety, ensuring a +//! namespace override is not used in a table (and vice versa), as well as to +//! ensure the correct chain of inheritance is adhered to at compile time. +//! +//! A partitioning template is resolved by evaluating the following (in order of +//! precedence): +//! +//! 1. Table name override, if specified. +//! 2. Namespace name override, if specified. +//! 3. 
Default partitioning scheme (YYYY-MM-DD) +//! +//! Each of the [`NamespacePartitionTemplateOverride`] and +//! [`TablePartitionTemplateOverride`] stores only the override, if provided, +//! and implicitly resolves to the default partitioning scheme if no override is +//! specified (indicated by the presence of [`Option::None`] in the wrapper). +//! +//! ## Default Partition Key +//! +//! The default partition key format is specified by [`PARTITION_BY_DAY_PROTO`], +//! with a template consisting of a single part: a YYYY-MM-DD representation of +//! the time row timestamp. +//! +//! ## Partition Key Format +//! +//! Should a partition template be used that generates a partition key +//! containing more than one part, those parts are delimited by the `|` +//! character ([`PARTITION_KEY_DELIMITER`]), chosen to be an unusual character +//! that is unlikely to occur in user-provided column values in order to +//! minimise the need to encode the value in the common case, while still +//! providing legible / printable keys. Should the user-provided column value +//! contain the `|` key, it is [percent encoded] (in addition to `!` below, and +//! the `%` character itself) to prevent ambiguity. +//! +//! It is an invariant that the resulting partition key derived from a given +//! template has the same number and ordering of parts. +//! +//! If the partition key template references a [`TemplatePart::TagValue`] column +//! that is not present in the row, a single `!` is inserted, indicating a NULL +//! template key part. If the value of the part is an empty string (""), a `^` +//! is inserted to ensure a non-empty partition key is always generated. Like +//! the `|` key above, any occurrence of these characters in a user-provided +//! column value is percent encoded. +//! +//! Because this serialisation format can be unambiguously reversed, the +//! [`build_column_values()`] function can be used to obtain the set of +//! 
[`TemplatePart::TagValue`] the key was constructed from. +//! +//! ### Value Truncation +//! +//! Partition key parts are limited to, at most, 200 bytes in length +//! ([`PARTITION_KEY_MAX_PART_LEN`]). If any single partition key part exceeds +//! this length limit, it is truncated and the truncation marker `#` +//! ([`PARTITION_KEY_PART_TRUNCATED`]) is appended. +//! +//! When rebuilding column values using [`build_column_values()`], a truncated +//! key part yields [`ColumnValue::Prefix`], which can only be used for prefix +//! matching - equality matching against a string always returns false. +//! +//! Two considerations must be made when truncating the generated key: +//! +//! * The string may contain encoded sequences in the form %XX, and the string +//! should not be split within an encoded sequence, or decoding the string +//! will fail. +//! +//! * This may be a unicode string - what the user might consider a "character" +//! may in fact be multiple unicode code-points, each of which may span +//! multiple bytes. +//! +//! Slicing a unicode code-point in half may lead to an invalid UTF-8 string, +//! which will prevent it from being used in Rust (and likely many other +//! languages/systems). Because partition keys are represented as strings and +//! not bytes, splitting a code-point in half MUST be avoided. +//! +//! Further to this, a sequence of multiple code-points can represent a single +//! "character" - this is called a grapheme. For example, the representation of +//! the Tamil "ni" character "நி" is composed of two multi-byte code-points; the +//! Tamil letter "na" which renders as "ந" and the vowel sign "ி", each 3 bytes +//! long. If split after the first 3 bytes, the compound "ni" character will be +//! incorrectly rendered as the single "na"/"ந" character. +//! +//! Depending on what the consumer of the split string considers a character, +//! prefix/equality matching may produce differing results if a grapheme is +//! split. 
If the caller performs a byte-wise comparison, everything is fine - +//! if they perform a "character" comparison, then the equality may be lost +//! depending on what they consider a character. +//! +//! Therefore this implementation takes the conservative approach of never +//! splitting code-points (for UTF-8 correctness) nor graphemes for simplicity +//! and compatibility for the consumer. This may be relaxed in the future to +//! allow splitting graphemes, but by being conservative we give ourselves this +//! option - we can't easily do the reverse! +//! +//! ## Part Limit & Maximum Key Size +//! +//! The number of parts in a partition template is limited to 8 +//! ([`MAXIMUM_NUMBER_OF_TEMPLATE_PARTS`]), validated at creation time. +//! +//! Together with the above value truncation, this bounds the maximum length of +//! a partition key to 1,607 bytes (1.57 KiB). +//! +//! ### Reserved Characters +//! +//! Reserved characters that are percent encoded (in addition to non-ASCII +//! characters), and their meaning: +//! +//! * `|` - partition key part delimiter ([`PARTITION_KEY_DELIMITER`]) +//! * `!` - NULL/missing partition key part ([`PARTITION_KEY_VALUE_NULL`]) +//! * `^` - empty string partition key part ([`PARTITION_KEY_VALUE_EMPTY`]) +//! * `#` - key part truncation marker ([`PARTITION_KEY_PART_TRUNCATED`]) +//! * `%` - required for unambiguous reversal of percent encoding +//! +//! These characters are defined in [`ENCODED_PARTITION_KEY_CHARS`] and chosen +//! due to their low likelihood of occurrence in user-provided column values. +//! +//! ### Reserved Tag Values +//! +//! Reserved tag values that cannot be used: +//! +//! * `time` - The time column has special meaning and is covered by strftime +//! formatters ([`TAG_VALUE_KEY_TIME`]) +//! +//! ### Examples +//! +//! When using the partition template below: +//! +//! ```text +//! [ +//! TemplatePart::TimeFormat("%Y"), +//! TemplatePart::TagValue("a"), +//! TemplatePart::TagValue("b"), +//! 
TemplatePart::Bucket("c", 10) +//! ] +//! ``` +//! +//! The following partition keys are derived: +//! +//! * `time=2023-01-01, a=bananas, b=plátanos, c=ananas` -> `2023|bananas|plátanos|5` +//! * `time=2023-01-01, b=plátanos` -> `2023|!|plátanos|!` +//! * `time=2023-01-01, another=cat, b=plátanos` -> `2023|!|plátanos|!` +//! * `time=2023-01-01` -> `2023|!|!|!` +//! * `time=2023-01-01, a=cat|dog, b=!, c=!` -> `2023|cat%7Cdog|%21|8` +//! * `time=2023-01-01, a=%50, c=%50` -> `2023|%2550|!|9` +//! * `time=2023-01-01, a=, c=` -> `2023|^|!|0` +//! * `time=2023-01-01, a=` -> `2023|#|!|!` +//! +//! When using the default partitioning template (YYYY-MM-DD) there is no +//! encoding necessary, as the derived partition key contains a single part, and +//! no reserved characters. +//! +//! [percent encoded]: https://url.spec.whatwg.org/#percent-encoded-bytes +use std::{ + borrow::Cow, + fmt::{Display, Formatter}, + ops::Range, + sync::Arc, +}; + +use chrono::{ + format::{Numeric, StrftimeItems}, + DateTime, Days, Months, Utc, +}; +use generated_types::influxdata::iox::partition_template::v1 as proto; +use murmur3::murmur3_32; +use once_cell::sync::Lazy; +use percent_encoding::{percent_decode_str, AsciiSet, CONTROLS}; +use schema::TIME_COLUMN_NAME; +use thiserror::Error; + +/// Reasons a user-specified partition template isn't valid. +#[derive(Debug, Error)] +#[allow(missing_copy_implementations)] +pub enum ValidationError { + /// The partition template didn't define any parts. + #[error("Custom partition template must have at least one part")] + NoParts, + + /// The partition template exceeded the maximum allowed number of parts. + #[error( + "Custom partition template specified {specified} parts. \ + Partition templates may have a maximum of {MAXIMUM_NUMBER_OF_TEMPLATE_PARTS} parts." + )] + TooManyParts { + /// The number of template parts that were present in the user-provided custom partition + /// template. 
+ specified: usize, + }, + + /// The partition template defines a [`TimeFormat`] part, but the + /// provided strftime formatter is invalid. + /// + /// [`TimeFormat`]: [`proto::template_part::Part::TimeFormat`] + #[error("invalid strftime format in partition template: {0}")] + InvalidStrftime(String), + + /// The partition template defines a [`TagValue`] part or [`Bucket`] part, + /// but the provided tag name value is invalid. + /// + /// [`TagValue`]: [`proto::template_part::Part::TagValue`] + /// [`Bucket`]: [`proto::template_part::Part::Bucket`] + #[error("invalid tag name value in partition template: {0}")] + InvalidTagValue(String), + + /// The partition template defines a [`Bucket`] part, but the provided + /// number of buckets is invalid. + /// + /// [`Bucket`]: [`proto::template_part::Part::Bucket`] + #[error( + "number of buckets in partition template must be in range \ + [{ALLOWED_BUCKET_QUANTITIES:?}), number specified: {0}" + )] + InvalidNumberOfBuckets(u32), + + /// The partition template defines a [`TagValue`] or [`Bucket`] part + /// which repeats a tag name used in another [`TagValue`] or [`Bucket`] part. + /// This is not allowed + /// + /// [`TagValue`]: [`proto::template_part::Part::TagValue`] + /// [`Bucket`]: [`proto::template_part::Part::Bucket`] + #[error("tag name value cannot be repeated in partition template: {0}")] + RepeatedTagValue(String), +} + +/// The maximum number of template parts a custom partition template may specify, to limit the +/// amount of space in the catalog used by the custom partition template and the partition keys +/// created with it. +pub const MAXIMUM_NUMBER_OF_TEMPLATE_PARTS: usize = 8; + +/// The sentinel character used to delimit partition key parts in the partition +/// key string. +pub const PARTITION_KEY_DELIMITER: char = '|'; + +/// The sentinel character used to indicate an empty string partition key part +/// in the partition key string. 
+pub const PARTITION_KEY_VALUE_EMPTY: char = '^'; + +/// The `str` form of the [`PARTITION_KEY_VALUE_EMPTY`] character. +pub const PARTITION_KEY_VALUE_EMPTY_STR: &str = "^"; + +/// The sentinel character used to indicate a missing partition key part in the +/// partition key string. +pub const PARTITION_KEY_VALUE_NULL: char = '!'; + +/// The `str` form of the [`PARTITION_KEY_VALUE_NULL`] character. +pub const PARTITION_KEY_VALUE_NULL_STR: &str = "!"; + +/// The maximum permissible length of a partition key part, after encoding +/// reserved & non-ASCII characters. +pub const PARTITION_KEY_MAX_PART_LEN: usize = 200; + +/// The truncation sentinel character, used to explicitly identify a partition +/// key as having been truncated. +/// +/// Truncated partition key parts can only be used for prefix matching, and +/// yield a [`ColumnValue::Prefix`] from [`build_column_values()`]. +pub const PARTITION_KEY_PART_TRUNCATED: char = '#'; + +/// The reserved tag value key for the `time` column, which is reserved as +/// a specifically formatted column for the time associated with any given +/// data point. +pub const TAG_VALUE_KEY_TIME: &str = "time"; + +/// The range of bucket quantities allowed for [`Bucket`] template parts. +/// +/// [`Bucket`]: [`proto::template_part::Part::Bucket`] +pub const ALLOWED_BUCKET_QUANTITIES: Range = Range { + start: 1, + end: 100_000, +}; + +/// The minimal set of characters that must be encoded during partition key +/// generation when they form part of a partition key part, in order to be +/// unambiguously reversible. +/// +/// See module-level documentation & [`build_column_values()`]. 
+pub const ENCODED_PARTITION_KEY_CHARS: AsciiSet = CONTROLS + .add(PARTITION_KEY_DELIMITER as u8) + .add(PARTITION_KEY_VALUE_NULL as u8) + .add(PARTITION_KEY_VALUE_EMPTY as u8) + .add(PARTITION_KEY_PART_TRUNCATED as u8) + .add(b'%'); // Required for reversible unambiguous encoding + +/// Allocationless and protobufless access to the parts of a template needed to +/// actually do partitioning. +#[derive(Debug, Clone)] +pub enum TemplatePart<'a> { + /// A tag-value partition part. + /// + /// Specifies the name of the tag column. + TagValue(&'a str), + + /// A strftime formatter. + /// + /// Specifies the formatter spec applied to the [`TIME_COLUMN_NAME`] column. + TimeFormat(&'a str), + + /// A bucketing partition part. + /// + /// Specifies the name of the tag column used to derive which of the `n` + /// buckets the data belongs in, through the mechanism implemented by the + /// [`bucket_for_tag_value`] function. + Bucket(&'a str, u32), +} + +/// The default partitioning scheme is by each day according to the "time" column. +pub static PARTITION_BY_DAY_PROTO: Lazy> = Lazy::new(|| { + Arc::new(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat( + "%Y-%m-%d".to_owned(), + )), + }], + }) +}); + +// This applies murmur3 32 bit hashing to the tag value string, as Iceberg would. +// +// * +fn iceberg_hash(tag_value: &str) -> u32 { + murmur3_32(&mut tag_value.as_bytes(), 0).expect("read of tag value string must never error") +} + +/// Hash bucket the provided tag value to a bucket ID in the range `[0,num_buckets)`. +/// +/// This applies murmur3 32 bit hashing to the tag value string, zero-ing the sign bit +/// then modulo assigning it to a bucket as Iceberg would. +/// +/// * +/// * +/// +/// +/// # Panics +/// +/// If `num_buckets` is zero, this will panic. Validation MUST prevent +/// [`TemplatePart::Bucket`] from being constructed with a zero bucket count. 
It just +/// makes no sense and shouldn't need to be checked here. +#[inline(always)] +pub fn bucket_for_tag_value(tag_value: &str, num_buckets: u32) -> u32 { + // Hash the tag value as iceberg would. + let hash = iceberg_hash(tag_value); + // Then bucket it as iceberg would, by removing the sign bit from the + // 32 bit murmur hash and modulo by the number of buckets to assign + // across. + (hash & i32::MAX as u32) % num_buckets +} + +/// A partition template specified by a namespace record. +/// +/// Internally this type is [`None`] when no namespace-level override is +/// specified, resulting in the default being used. +#[derive(Debug, PartialEq, Clone, Default, sqlx::Type, Hash)] +#[sqlx(transparent, no_pg_array)] +pub struct NamespacePartitionTemplateOverride(Option); + +impl NamespacePartitionTemplateOverride { + /// A const "default" impl for testing. + pub const fn const_default() -> Self { + Self(None) + } + + /// Return the protobuf representation of this template. + pub fn as_proto(&self) -> Option<&proto::PartitionTemplate> { + self.0.as_ref().map(|v| v.inner()) + } +} + +impl TryFrom for NamespacePartitionTemplateOverride { + type Error = ValidationError; + + fn try_from(partition_template: proto::PartitionTemplate) -> Result { + Ok(Self(Some(serialization::Wrapper::try_from( + partition_template, + )?))) + } +} + +/// A partition template specified by a table record. +#[derive(Debug, PartialEq, Eq, Clone, Default, sqlx::Type, Hash)] +#[sqlx(transparent, no_pg_array)] +pub struct TablePartitionTemplateOverride(Option); + +impl TablePartitionTemplateOverride { + /// When a table is being explicitly created, the creation request might have contained a + /// custom partition template for that table. If the custom partition template is present, use + /// it. Otherwise, use the namespace's partition template. + /// + /// # Errors + /// + /// This function will return an error if the custom partition template specified is invalid. 
+ pub fn try_new( + custom_table_template: Option, + namespace_template: &NamespacePartitionTemplateOverride, + ) -> Result { + match (custom_table_template, namespace_template.0.as_ref()) { + (Some(table_proto), _) => { + Ok(Self(Some(serialization::Wrapper::try_from(table_proto)?))) + } + (None, Some(namespace_serialization_wrapper)) => { + Ok(Self(Some(namespace_serialization_wrapper.clone()))) + } + (None, None) => Ok(Self(None)), + } + } + + /// Returns the number of parts in this template. + #[allow(clippy::len_without_is_empty)] // Senseless - there must always be >0 parts. + pub fn len(&self) -> usize { + self.parts().count() + } + + /// Iterate through the protobuf parts and lend out what the `mutable_batch` crate needs to + /// build `PartitionKey`s. If this table doesn't have a custom template, use the application + /// default of partitioning by day. + pub fn parts(&self) -> impl Iterator> { + self.0 + .as_ref() + .map(|serialization_wrapper| serialization_wrapper.inner()) + .unwrap_or_else(|| &PARTITION_BY_DAY_PROTO) + .parts + .iter() + .flat_map(|part| part.part.as_ref()) + .map(|part| match part { + proto::template_part::Part::TagValue(value) => TemplatePart::TagValue(value), + proto::template_part::Part::TimeFormat(fmt) => TemplatePart::TimeFormat(fmt), + proto::template_part::Part::Bucket(proto::Bucket { + tag_name, + num_buckets, + }) => TemplatePart::Bucket(tag_name, *num_buckets), + }) + } + + /// Size in bytes, including `self`. + /// + /// This accounts for the entire allocation of this object, even when it shared (via an internal [`Arc`]). 
+ pub fn size(&self) -> usize { + std::mem::size_of_val(self) + + self + .0 + .as_ref() + .map(|wrapper| { + let inner = wrapper.inner(); + + // inner is wrapped into an Arc, so we need to account for that allocation + std::mem::size_of::() + + (inner.parts.capacity() * std::mem::size_of::()) + + inner + .parts + .iter() + .map(|part| { + part.part + .as_ref() + .map(|part| match part { + proto::template_part::Part::TagValue(s) => s.capacity(), + proto::template_part::Part::TimeFormat(s) => s.capacity(), + proto::template_part::Part::Bucket(proto::Bucket { + tag_name, + num_buckets: _, + }) => tag_name.capacity() + std::mem::size_of::(), + }) + .unwrap_or_default() + }) + .sum::() + }) + .unwrap_or_default() + } + + /// Return the protobuf representation of this template. + pub fn as_proto(&self) -> Option<&proto::PartitionTemplate> { + self.0.as_ref().map(|v| v.inner()) + } +} + +/// Display the serde_json representation so that the output +/// can be copy/pasted into CLI tools, etc as the partition +/// template is specified as JSON +impl Display for TablePartitionTemplateOverride { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + self.as_proto() + .map(|proto| serde_json::to_string(proto) + .expect("serialization should be infallible")) + .unwrap_or_default() + ) + } +} + +impl TryFrom> for TablePartitionTemplateOverride { + type Error = ValidationError; + + fn try_from(p: Option) -> Result { + Ok(Self(p.map(serialization::Wrapper::try_from).transpose()?)) + } +} + +/// This manages the serialization/deserialization of the `proto::PartitionTemplate` type to and +/// from the database through `sqlx` for the `NamespacePartitionTemplateOverride` and +/// `TablePartitionTemplateOverride` types. It's an internal implementation detail to minimize code +/// duplication. 
+mod serialization { + use super::{ + ValidationError, ALLOWED_BUCKET_QUANTITIES, MAXIMUM_NUMBER_OF_TEMPLATE_PARTS, + TAG_VALUE_KEY_TIME, + }; + use chrono::{format::StrftimeItems, Utc}; + use generated_types::influxdata::iox::partition_template::v1 as proto; + use std::{collections::HashSet, fmt::Write, sync::Arc}; + + #[derive(Debug, Clone, PartialEq, Hash)] + pub struct Wrapper(Arc); + + impl Wrapper { + /// Read access to the inner proto + pub fn inner(&self) -> &proto::PartitionTemplate { + &self.0 + } + + /// THIS IS FOR TESTING PURPOSES ONLY AND SHOULD NOT BE USED IN PRODUCTION CODE. + /// + /// The application shouldn't be putting invalid templates into the database because all + /// creation of `Wrapper`s should be going through the + /// `TryFrom::try_from` constructor that rejects invalid + /// templates. However, that leaves the possibility of the database getting an invalid + /// template through some other means, and we want to be able to construct those easily in + /// tests to make sure code using partition templates can handle the unlikely possibility + /// of an invalid template in the database. + pub(super) fn for_testing_possibility_of_invalid_value_in_database( + proto: proto::PartitionTemplate, + ) -> Self { + Self(Arc::new(proto)) + } + } + + // protobuf types normally don't implement `Eq`, but for this concrete type this is OK + impl Eq for Wrapper {} + + impl TryFrom for Wrapper { + type Error = ValidationError; + + fn try_from(partition_template: proto::PartitionTemplate) -> Result { + // There must be at least one part. + if partition_template.parts.is_empty() { + return Err(ValidationError::NoParts); + } + + // There may not be more than `MAXIMUM_NUMBER_OF_TEMPLATE_PARTS` parts. 
+ let specified = partition_template.parts.len(); + if specified > MAXIMUM_NUMBER_OF_TEMPLATE_PARTS { + return Err(ValidationError::TooManyParts { specified }); + } + + let mut seen_tags: HashSet<&str> = HashSet::with_capacity(specified); + + // All time formats must be valid and tag values may not specify any + // restricted values. + for part in &partition_template.parts { + match &part.part { + Some(proto::template_part::Part::TimeFormat(fmt)) => { + // Empty is not a valid time format + if fmt.is_empty() { + return Err(ValidationError::InvalidStrftime(fmt.into())); + } + + // Chrono will panic during timestamp formatting if this + // formatter directive is used! + // + // An upper-case Z does not trigger the panic code path so + // is not checked for. + if fmt.contains("%#z") { + return Err(ValidationError::InvalidStrftime( + "%#z cannot be used".to_string(), + )); + } + + // Currently we can only tell whether a nonempty format is valid by trying + // to use it. See + let mut dev_null = String::new(); + write!( + dev_null, + "{}", + Utc::now().format_with_items(StrftimeItems::new(fmt)) + ) + .map_err(|_| ValidationError::InvalidStrftime(fmt.into()))? 
+ } + Some(proto::template_part::Part::TagValue(value)) => { + // Empty is not a valid tag value + if value.is_empty() { + return Err(ValidationError::InvalidTagValue(value.into())); + } + + if value.contains(TAG_VALUE_KEY_TIME) { + return Err(ValidationError::InvalidTagValue(format!( + "{TAG_VALUE_KEY_TIME} cannot be used" + ))); + } + + if !seen_tags.insert(value.as_str()) { + return Err(ValidationError::RepeatedTagValue(value.into())); + } + } + Some(proto::template_part::Part::Bucket(proto::Bucket { + tag_name, + num_buckets, + })) => { + if tag_name.is_empty() { + return Err(ValidationError::InvalidTagValue(tag_name.into())); + } + + if tag_name.contains(TAG_VALUE_KEY_TIME) { + return Err(ValidationError::InvalidTagValue(format!( + "{TAG_VALUE_KEY_TIME} cannot be used" + ))); + } + + if !seen_tags.insert(tag_name.as_str()) { + return Err(ValidationError::RepeatedTagValue(tag_name.into())); + } + + if !ALLOWED_BUCKET_QUANTITIES.contains(num_buckets) { + return Err(ValidationError::InvalidNumberOfBuckets(*num_buckets)); + } + } + None => {} + } + } + + Ok(Self(Arc::new(partition_template))) + } + } + + impl sqlx::Type for Wrapper + where + sqlx::types::Json: sqlx::Type, + DB: sqlx::Database, + { + fn type_info() -> DB::TypeInfo { + as sqlx::Type>::type_info() + } + } + + impl<'q, DB> sqlx::Encode<'q, DB> for Wrapper + where + DB: sqlx::Database, + for<'b> sqlx::types::Json<&'b proto::PartitionTemplate>: sqlx::Encode<'q, DB>, + { + fn encode_by_ref( + &self, + buf: &mut >::ArgumentBuffer, + ) -> sqlx::encode::IsNull { + as sqlx::Encode<'_, DB>>::encode_by_ref( + &sqlx::types::Json(&self.0), + buf, + ) + } + } + + impl<'q, DB> sqlx::Decode<'q, DB> for Wrapper + where + DB: sqlx::Database, + sqlx::types::Json: sqlx::Decode<'q, DB>, + { + fn decode( + value: >::ValueRef, + ) -> Result> { + Ok(Self( + as sqlx::Decode<'_, DB>>::decode( + value, + )? + .0 + .into(), + )) + } + } +} + +/// The value of a column, reversed from a partition key. 
+/// +/// See [`build_column_values()`]. +#[derive(Debug, Clone, PartialEq)] +pub enum ColumnValue<'a> { + /// The inner value is the exact, unmodified input column value. + Identity(Cow<'a, str>), + + /// The inner value is a variable length prefix of the input column value. + /// + /// The string value is always guaranteed to be valid UTF-8. + /// + /// Attempting to equality match this variant against a string will always + /// be false - use [`ColumnValue::is_prefix_match_of()`] to prefix match + /// instead. + Prefix(Cow<'a, str>), + + /// Datetime. + Datetime { + /// Inclusive begin of the datatime partition range. + begin: DateTime, + + /// Exclusive end of the datatime partition range. + end: DateTime, + }, + + /// The inner value is the ID of the bucket selected through a modulo hash + /// of the input column value. + Bucket(u32), +} + +impl<'a> ColumnValue<'a> { + /// Returns true if `other` is a byte-wise prefix match of `self`. + /// + /// This method can be called for both [`ColumnValue::Identity`] and + /// [`ColumnValue::Prefix`]. + pub fn is_prefix_match_of(&self, other: T) -> bool + where + T: AsRef<[u8]>, + { + let this = match self { + ColumnValue::Identity(v) => v.as_bytes(), + ColumnValue::Prefix(v) => v.as_bytes(), + ColumnValue::Datetime { .. } | ColumnValue::Bucket(..) => { + return false; + } + }; + + other.as_ref().starts_with(this) + } +} + +impl<'a, T> PartialEq for ColumnValue<'a> +where + T: AsRef, +{ + fn eq(&self, other: &T) -> bool { + match self { + ColumnValue::Identity(v) => other.as_ref().eq(v.as_ref()), + ColumnValue::Prefix(_) => false, + ColumnValue::Datetime { .. } => false, + ColumnValue::Bucket(..) => false, + } + } +} + +/// Reverse a `partition_key` generated from the given partition key `template`, +/// reconstructing the set of tag values in the form of `(column name, column +/// value)` tuples that the `partition_key` was generated from. +/// +/// The `partition_key` MUST have been generated by `template`. 
+/// +/// Values are returned as a [`Cow`], avoiding the need for value copying if +/// they do not need decoding. See module docs for encoding/decoding. +/// +/// # Panics +/// +/// This method panics if a column value is not valid UTF8 after decoding, or +/// when a bucket ID is not valid (not a u32 or within the expected number of +/// buckets). +pub fn build_column_values<'a>( + template: &'a TablePartitionTemplateOverride, + partition_key: &'a str, +) -> impl Iterator)> { + // Exploded parts of the generated key on the "/" character. + // + // Any uses of the "/" character within the partition key's user-provided + // values are url encoded, so this is an unambiguous field separator. + let key_parts = partition_key.split(PARTITION_KEY_DELIMITER); + + // Obtain an iterator of template parts, from which the meaning of the key + // parts can be inferred. + let template_parts = template.parts(); + + // Invariant: the number of key parts generated from a given template always + // matches the number of template parts. + // + // The key_parts iterator is not an ExactSizeIterator, so an assert can't be + // placed here to validate this property. + + // Produce an iterator of (template_part, template_value) + template_parts + .zip(key_parts) + .filter_map(|(template, value)| { + if value == PARTITION_KEY_VALUE_NULL_STR { + None + } else { + match template { + TemplatePart::TagValue(col_name) => { + Some((col_name, parse_part_tag_value(value)?)) + } + TemplatePart::TimeFormat(format) => { + Some((TIME_COLUMN_NAME, parse_part_time_format(value, format)?)) + } + TemplatePart::Bucket(col_name, num_buckets) => { + Some((col_name, parse_part_bucket(value, num_buckets)?)) + } + } + } + }) +} + +fn parse_part_tag_value(value: &str) -> Option> { + // Perform re-mapping of sentinel values. + let value = match value { + PARTITION_KEY_VALUE_EMPTY_STR => { + // Re-map the empty string sentinel "^"" to an empty string + // value. 
+ "" + } + _ => value, + }; + + // Reverse the urlencoding of all value parts + let decoded = percent_decode_str(value) + .decode_utf8() + .expect("invalid partition key part encoding"); + + // Inspect the final character in the string, pre-decoding, to + // determine if it has been truncated. + if value + .as_bytes() + .last() + .map(|v| *v == PARTITION_KEY_PART_TRUNCATED as u8) + .unwrap_or_default() + { + // Remove the truncation marker. + let len = decoded.len() - 1; + + // Only allocate if needed; re-borrow a subslice of `Cow::Borrowed` if not. + let column_cow = match decoded { + Cow::Borrowed(s) => Cow::Borrowed(&s[..len]), + Cow::Owned(s) => Cow::Owned(s[..len].to_string()), + }; + Some(ColumnValue::Prefix(column_cow)) + } else { + Some(ColumnValue::Identity(decoded)) + } +} + +fn parse_part_time_format(value: &str, format: &str) -> Option> { + use chrono::format::{parse, Item, Parsed}; + + let items = StrftimeItems::new(format); + + let mut parsed = Parsed::new(); + parse(&mut parsed, value, items.clone()).ok()?; + + // fill in defaults + let parsed = parsed_implicit_defaults(parsed)?; + + let begin = parsed.to_datetime_with_timezone(&Utc).ok()?; + + let mut end: Option> = None; + for item in items { + let item_end = match item { + Item::Literal(_) | Item::OwnedLiteral(_) | Item::Space(_) | Item::OwnedSpace(_) => None, + Item::Error => { + return None; + } + Item::Numeric(numeric, _pad) => { + match numeric { + Numeric::Year => Some(begin + Months::new(12)), + Numeric::Month => Some(begin + Months::new(1)), + Numeric::Day => Some(begin + Days::new(1)), + _ => { + // not supported + return None; + } + } + } + Item::Fixed(_) => { + // not implemented + return None; + } + }; + + end = match (end, item_end) { + (Some(a), Some(b)) => { + let a_d = a - begin; + let b_d = b - begin; + if a_d < b_d { + Some(a) + } else { + Some(b) + } + } + (None, Some(dt)) => Some(dt), + (Some(dt), None) => Some(dt), + (None, None) => None, + }; + } + + end.map(|end| 
ColumnValue::Datetime { begin, end }) +} + +fn parse_part_bucket(value: &str, num_buckets: u32) -> Option> { + // Parse the bucket ID from the given value string. + let bucket_id = value + .parse::() + .expect("invalid partition key bucket encoding"); + // Invariant: If the bucket ID (0 indexed) is greater than the number of + // buckets to spread data across the partition key is invalid. + assert!(bucket_id < num_buckets); + + Some(ColumnValue::Bucket(bucket_id)) +} + +fn parsed_implicit_defaults(mut parsed: chrono::format::Parsed) -> Option { + parsed.year?; + + if parsed.month.is_none() { + if parsed.day.is_some() { + return None; + } + + parsed.set_month(1).ok()?; + } + + if parsed.day.is_none() { + if parsed.hour_div_12.is_some() || parsed.hour_mod_12.is_some() { + return None; + } + + parsed.set_day(1).ok()?; + } + + if parsed.hour_div_12.is_none() || parsed.hour_mod_12.is_none() { + // consistency check + if parsed.hour_div_12.is_some() { + return None; + } + if parsed.hour_mod_12.is_some() { + return None; + } + + if parsed.minute.is_some() { + return None; + } + + parsed.set_hour(0).ok()?; + } + + if parsed.minute.is_none() { + if parsed.second.is_some() { + return None; + } + if parsed.nanosecond.is_some() { + return None; + } + + parsed.set_minute(0).ok()?; + } + + Some(parsed) +} + +/// In production code, the template should come from protobuf that is either from the database or +/// from a gRPC request. In tests, building protobuf is painful, so here's an easier way to create +/// a `TablePartitionTemplateOverride`. +/// +/// This deliberately goes around the validation of the templates so that tests can verify code +/// handles potentially invalid templates! 
+pub fn test_table_partition_override( + parts: Vec>, +) -> TablePartitionTemplateOverride { + let parts = parts + .into_iter() + .map(|part| { + let part = match part { + TemplatePart::TagValue(value) => proto::template_part::Part::TagValue(value.into()), + TemplatePart::TimeFormat(fmt) => proto::template_part::Part::TimeFormat(fmt.into()), + TemplatePart::Bucket(value, num_buckets) => { + proto::template_part::Part::Bucket(proto::Bucket { + tag_name: value.into(), + num_buckets, + }) + } + }; + + proto::TemplatePart { part: Some(part) } + }) + .collect(); + + let proto = proto::PartitionTemplate { parts }; + TablePartitionTemplateOverride(Some( + serialization::Wrapper::for_testing_possibility_of_invalid_value_in_database(proto), + )) +} + +#[cfg(test)] +mod tests { + use assert_matches::assert_matches; + use chrono::TimeZone; + use proptest::prelude::*; + use sqlx::Encode; + use test_helpers::assert_error; + + use super::*; + + #[test] + fn test_partition_template_to_string() { + let template_empty: TablePartitionTemplateOverride = + TablePartitionTemplateOverride::default(); + + let template: Vec> = + [TemplatePart::TimeFormat("%Y"), TemplatePart::TagValue("a")] + .into_iter() + .collect::>(); + let template: TablePartitionTemplateOverride = test_table_partition_override(template); + + assert_eq!(template_empty.to_string(), ""); + assert_eq!( + template.to_string(), + "{\"parts\":[{\"timeFormat\":\"%Y\"},{\"tagValue\":\"a\"}]}" + ); + } + + #[test] + fn test_max_partition_key_len() { + let max_len: usize = + // 8 parts, at most 200 bytes long. + (MAXIMUM_NUMBER_OF_TEMPLATE_PARTS * PARTITION_KEY_MAX_PART_LEN) + // 7 delimiting characters between parts. + + (MAXIMUM_NUMBER_OF_TEMPLATE_PARTS - 1); + + // If this changes, the module documentation should be changed too. + // + // This shouldn't change without consideration of primary key overlap as + // a result. 
+ assert_eq!(max_len, 1_607, "update module docs please"); + } + + #[test] + fn empty_parts_is_invalid() { + let err = serialization::Wrapper::try_from(proto::PartitionTemplate { parts: vec![] }); + + assert_error!(err, ValidationError::NoParts); + } + + #[test] + fn more_than_8_parts_is_invalid() { + let err = serialization::Wrapper::try_from(proto::PartitionTemplate { + parts: vec![ + proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue("region".into())), + }, + proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat("year-%Y".into())), + }, + proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue("region".into())), + }, + proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat("year-%Y".into())), + }, + proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue("region".into())), + }, + proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat("year-%Y".into())), + }, + proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue("region".into())), + }, + proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat("year-%Y".into())), + }, + proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue("region".into())), + }, + ], + }); + + assert_error!(err, ValidationError::TooManyParts { specified } if specified == 9); + } + + #[test] + fn repeated_tag_name_value_is_invalid() { + // Test [`TagValue`] + let err = serialization::Wrapper::try_from(proto::PartitionTemplate { + parts: vec![ + proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue("bananas".into())), + }, + proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue("bananas".into())), + }, + ], + }); + + assert_error!(err, ValidationError::RepeatedTagValue ( ref specified ) if specified == "bananas"); + + // Test [`Bucket`] + let err = serialization::Wrapper::try_from(proto::PartitionTemplate { + parts: vec![ + proto::TemplatePart { + 
part: Some(proto::template_part::Part::Bucket(proto::Bucket { + tag_name: "bananas".into(), + num_buckets: 42, + })), + }, + proto::TemplatePart { + part: Some(proto::template_part::Part::Bucket(proto::Bucket { + tag_name: "bananas".into(), + num_buckets: 42, + })), + }, + ], + }); + + assert_error!(err, ValidationError::RepeatedTagValue ( ref specified ) if specified == "bananas"); + + // Test a combination of [`TagValue`] and [`Bucket`] + let err = serialization::Wrapper::try_from(proto::PartitionTemplate { + parts: vec![ + proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue("bananas".into())), + }, + proto::TemplatePart { + part: Some(proto::template_part::Part::Bucket(proto::Bucket { + tag_name: "bananas".into(), + num_buckets: 42, + })), + }, + ], + }); + + assert_error!(err, ValidationError::RepeatedTagValue ( ref specified ) if specified == "bananas"); + } + + /// Chrono will panic when formatting a timestamp if the "%#z" formatting + /// directive is used... + #[test] + fn test_secret_formatter_advice_panic() { + let err = serialization::Wrapper::try_from(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat("%#z".into())), + }], + }); + + assert_error!(err, ValidationError::InvalidStrftime(_)); + + // This doesn't trigger the panic, but is included for completeness. 
+ let err = serialization::Wrapper::try_from(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat("%#Z".into())), + }], + }); + + assert_error!(err, ValidationError::InvalidStrftime(_)); + } + + #[test] + fn invalid_strftime_format_is_invalid() { + let err = serialization::Wrapper::try_from(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat("%3F".into())), + }], + }); + + assert_error!(err, ValidationError::InvalidStrftime(ref format) if format == "%3F"); + } + + #[test] + fn empty_strftime_format_is_invalid() { + let err = serialization::Wrapper::try_from(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat("".into())), + }], + }); + + assert_error!(err, ValidationError::InvalidStrftime(ref format) if format.is_empty()); + } + + /// "time" is a special column already covered by strftime, being a time + /// series database and all. + #[test] + fn time_tag_value_is_invalid() { + let err = serialization::Wrapper::try_from(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue("time".into())), + }], + }); + + assert_error!(err, ValidationError::InvalidTagValue(_)); + } + + #[test] + fn empty_tag_value_is_invalid() { + let err = serialization::Wrapper::try_from(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue("".into())), + }], + }); + + assert_error!(err, ValidationError::InvalidTagValue(ref value) if value.is_empty()); + } + + /// "time" is a special column already covered by strftime, being a time + /// series database and all. 
+ #[test] + fn bucket_time_tag_name_is_invalid() { + let err = serialization::Wrapper::try_from(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::Bucket(proto::Bucket { + tag_name: "time".into(), + num_buckets: 42, + })), + }], + }); + + assert_error!(err, ValidationError::InvalidTagValue(_)); + } + + #[test] + fn bucket_empty_tag_name_is_invalid() { + let err = serialization::Wrapper::try_from(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::Bucket(proto::Bucket { + tag_name: "".into(), + num_buckets: 42, + })), + }], + }); + + assert_error!(err, ValidationError::InvalidTagValue(ref value) if value.is_empty()); + } + + #[test] + fn bucket_zero_num_buckets_is_invalid() { + let err = serialization::Wrapper::try_from(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::Bucket(proto::Bucket { + tag_name: "arán".into(), + num_buckets: 0, + })), + }], + }); + + assert_error!(err, ValidationError::InvalidNumberOfBuckets(0)); + } + + #[test] + fn bucket_too_high_num_buckets_is_invalid() { + const TOO_HIGH: u32 = 100_000; + + let err = serialization::Wrapper::try_from(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::Bucket(proto::Bucket { + tag_name: "arán".into(), + num_buckets: TOO_HIGH, + })), + }], + }); + + assert_error!(err, ValidationError::InvalidNumberOfBuckets(TOO_HIGH)); + } + + fn identity(s: &str) -> ColumnValue<'_> { + ColumnValue::Identity(s.into()) + } + + fn bucket(bucket_id: u32) -> ColumnValue<'static> { + ColumnValue::Bucket(bucket_id) + } + + fn prefix<'a, T>(s: T) -> ColumnValue<'a> + where + T: Into>, + { + ColumnValue::Prefix(s.into()) + } + + fn year(y: i32) -> ColumnValue<'static> { + ColumnValue::Datetime { + begin: Utc.with_ymd_and_hms(y, 1, 1, 0, 0, 0).unwrap(), + end: Utc.with_ymd_and_hms(y + 1, 1, 1, 0, 0, 0).unwrap(), + } + } + + 
#[test] + fn test_iceberg_string_hash() { + assert_eq!(iceberg_hash("iceberg"), 1210000089); + } + + // This is a test fixture designed to catch accidental changes to the + // Iceberg-like hash-bucket partitioning behaviour. + // + // You shouldn't be changing this! + #[test] + fn test_hash_bucket_fixture() { + // These are values lifted from the iceberg spark test suite for + // `BucketString`, sadly not provided in the reference/spec: + // + // https://github.com/apache/iceberg/blob/31e31fd819c846f49d2bd459b8bfadfdc3c2bc3a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkBucketFunction.java#L151-L169 + // + assert_eq!(bucket_for_tag_value("abcdefg", 5), 4); + assert_eq!(bucket_for_tag_value("abc", 128), 122); + assert_eq!(bucket_for_tag_value("abcde", 64), 54); + assert_eq!(bucket_for_tag_value("测试", 12), 8); + assert_eq!(bucket_for_tag_value("测试raul试测", 16), 1); + assert_eq!(bucket_for_tag_value("", 16), 0); + + // These are pre-existing arbitrary fixture values + assert_eq!(bucket_for_tag_value("bananas", 10), 1); + assert_eq!(bucket_for_tag_value("plátanos", 100), 98); + assert_eq!(bucket_for_tag_value("crobhaing bananaí", 1000), 166); + assert_eq!(bucket_for_tag_value("bread", 42), 9); + assert_eq!(bucket_for_tag_value("arán", 76), 72); + assert_eq!(bucket_for_tag_value("banana arán", 1337), 1284); + assert_eq!( + bucket_for_tag_value("uasmhéid bananaí", u32::MAX), + 1109892861 + ); + } + + /// Test to approximate and show how the tag value maps to the partition key + /// for the example cases in the mod-doc. The behaviour that renders the key + /// itself is a combination of this bucket assignment and the render logic. + #[test] + fn test_bucket_for_mod_doc() { + assert_eq!(bucket_for_tag_value("ananas", 10), 5); + assert_eq!(bucket_for_tag_value("!", 10), 8); + assert_eq!(bucket_for_tag_value("%50", 10), 9); + assert_eq!(bucket_for_tag_value("", 10), 0); + } + + proptest! 
{ + #[test] + fn prop_consistent_bucketing_within_limits(tag_values in proptest::collection::vec(any::(), (1, 10)), num_buckets in any::()) { + for value in tag_values { + // First pass assign + let want_bucket = bucket_for_tag_value(&value, num_buckets); + // The assigned bucket must fit within the domain given to the bucketer. + assert!(want_bucket < num_buckets); + // Feed in the same tag value, expect the same result. + let got_bucket = bucket_for_tag_value(&value, num_buckets); + assert_eq!(want_bucket, got_bucket); + } + } + } + + /// Generate a test that asserts "partition_key" is reversible, yielding + /// "want" assuming the partition "template" was used. + macro_rules! test_build_column_values { + ( + $name:ident, + template = $template:expr, // Array/vec of TemplatePart + partition_key = $partition_key:expr, // String derived partition key + want = $want:expr // Expected build_column_values() output + ) => { + paste::paste! { + #[test] + fn []() { + let template = $template.into_iter().collect::>(); + let template = test_table_partition_override(template); + + // normalise the values into a (str, ColumnValue) for the comparison + let want = $want + .into_iter() + .collect::>(); + + let input = String::from($partition_key); + let got = build_column_values(&template, input.as_str()) + .collect::>(); + + assert_eq!(got, want); + } + } + }; + } + + test_build_column_values!( + module_doc_example_1, + template = [ + TemplatePart::TimeFormat("%Y"), + TemplatePart::TagValue("a"), + TemplatePart::TagValue("b"), + TemplatePart::Bucket("c", 10), + ], + partition_key = "2023|bananas|plátanos|5", + want = [ + (TIME_COLUMN_NAME, year(2023)), + ("a", identity("bananas")), + ("b", identity("plátanos")), + ("c", bucket(5)), + ] + ); + + test_build_column_values!( + module_doc_example_2, // Examples 2 and 3 are the same partition key + template = [ + TemplatePart::TimeFormat("%Y"), + TemplatePart::TagValue("a"), + TemplatePart::TagValue("b"), + TemplatePart::Bucket("c", 
10), + ], + partition_key = "2023|!|plátanos|!", + want = [(TIME_COLUMN_NAME, year(2023)), ("b", identity("plátanos")),] + ); + + test_build_column_values!( + module_doc_example_4, + template = [ + TemplatePart::TimeFormat("%Y"), + TemplatePart::TagValue("a"), + TemplatePart::TagValue("b"), + TemplatePart::Bucket("c", 10), + ], + partition_key = "2023|!|!|!", + want = [(TIME_COLUMN_NAME, year(2023)),] + ); + + test_build_column_values!( + module_doc_example_5, + template = [ + TemplatePart::TimeFormat("%Y"), + TemplatePart::TagValue("a"), + TemplatePart::TagValue("b"), + TemplatePart::Bucket("c", 10), + ], + partition_key = "2023|cat%7Cdog|%21|8", + want = [ + (TIME_COLUMN_NAME, year(2023)), + ("a", identity("cat|dog")), + ("b", identity("!")), + ("c", bucket(8)), + ] + ); + + test_build_column_values!( + module_doc_example_6, + template = [ + TemplatePart::TimeFormat("%Y"), + TemplatePart::TagValue("a"), + TemplatePart::TagValue("b"), + TemplatePart::Bucket("c", 10), + ], + partition_key = "2023|%2550|!|9", + want = [ + (TIME_COLUMN_NAME, year(2023)), + ("a", identity("%50")), + ("c", bucket(9)), + ] + ); + + test_build_column_values!( + module_doc_example_7, + template = [ + TemplatePart::TimeFormat("%Y"), + TemplatePart::TagValue("a"), + TemplatePart::TagValue("b"), + TemplatePart::Bucket("c", 10), + ], + partition_key = "2023|^|!|0", + want = [ + (TIME_COLUMN_NAME, year(2023)), + ("a", identity("")), + ("c", bucket(0)), + ] + ); + + test_build_column_values!( + module_doc_example_8, + template = [ + TemplatePart::TimeFormat("%Y"), + TemplatePart::TagValue("a"), + TemplatePart::TagValue("b"), + TemplatePart::Bucket("c", 10), + ], + partition_key = "2023|BANANAS#|!|!|!", + want = [(TIME_COLUMN_NAME, year(2023)), ("a", prefix("BANANAS")),] + ); + + test_build_column_values!( + unicode_code_point_prefix, + template = [ + TemplatePart::TimeFormat("%Y"), + TemplatePart::TagValue("a"), + TemplatePart::TagValue("b"), + TemplatePart::Bucket("c", 10), + ], + 
partition_key = "2023|%28%E3%83%8E%E0%B2%A0%E7%9B%8A%E0%B2%A0%29%E3%83%8E%E5%BD%A1%E2%94%BB%E2%94%81%E2%94%BB#|!|!", + want = [ + (TIME_COLUMN_NAME, year(2023)), + ("a", prefix("(ノಠ益ಠ)ノ彡┻━┻")), + ] + ); + + test_build_column_values!( + unicode_grapheme, + template = [ + TemplatePart::TimeFormat("%Y"), + TemplatePart::TagValue("a"), + TemplatePart::TagValue("b"), + ], + partition_key = "2023|%E0%AE%A8%E0%AE%BF#|!", + want = [(TIME_COLUMN_NAME, year(2023)), ("a", prefix("நி")),] + ); + + test_build_column_values!( + unambiguous, + template = [ + TemplatePart::TimeFormat("%Y"), + TemplatePart::TagValue("a"), + TemplatePart::TagValue("b"), + ], + partition_key = "2023|is%7Cnot%21ambiguous%2510%23|!", + want = [ + (TIME_COLUMN_NAME, year(2023)), + ("a", identity("is|not!ambiguous%10#")), + ] + ); + + test_build_column_values!( + datetime_fixed, + template = [TemplatePart::TimeFormat("foo"),], + partition_key = "foo", + want = [] + ); + + test_build_column_values!( + datetime_null, + template = [TemplatePart::TimeFormat("%Y"),], + partition_key = "!", + want = [] + ); + + test_build_column_values!( + datetime_range_y, + template = [TemplatePart::TimeFormat("%Y"),], + partition_key = "2023", + want = [( + TIME_COLUMN_NAME, + ColumnValue::Datetime { + begin: Utc.with_ymd_and_hms(2023, 1, 1, 0, 0, 0).unwrap(), + end: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(), + }, + )] + ); + + test_build_column_values!( + datetime_range_y_m, + template = [TemplatePart::TimeFormat("%Y-%m"),], + partition_key = "2023-09", + want = [( + TIME_COLUMN_NAME, + ColumnValue::Datetime { + begin: Utc.with_ymd_and_hms(2023, 9, 1, 0, 0, 0).unwrap(), + end: Utc.with_ymd_and_hms(2023, 10, 1, 0, 0, 0).unwrap(), + }, + )] + ); + + test_build_column_values!( + datetime_range_y_m_overflow_year, + template = [TemplatePart::TimeFormat("%Y-%m"),], + partition_key = "2023-12", + want = [( + TIME_COLUMN_NAME, + ColumnValue::Datetime { + begin: Utc.with_ymd_and_hms(2023, 12, 1, 0, 0, 0).unwrap(), + end: 
Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(), + }, + )] + ); + + test_build_column_values!( + datetime_range_y_m_d, + template = [TemplatePart::TimeFormat("%Y-%m-%d"),], + partition_key = "2023-09-01", + want = [( + TIME_COLUMN_NAME, + ColumnValue::Datetime { + begin: Utc.with_ymd_and_hms(2023, 9, 1, 0, 0, 0).unwrap(), + end: Utc.with_ymd_and_hms(2023, 9, 2, 0, 0, 0).unwrap(), + }, + )] + ); + + test_build_column_values!( + datetime_range_y_m_d_overflow_month, + template = [TemplatePart::TimeFormat("%Y-%m-%d"),], + partition_key = "2023-09-30", + want = [( + TIME_COLUMN_NAME, + ColumnValue::Datetime { + begin: Utc.with_ymd_and_hms(2023, 9, 30, 0, 0, 0).unwrap(), + end: Utc.with_ymd_and_hms(2023, 10, 1, 0, 0, 0).unwrap(), + }, + )] + ); + + test_build_column_values!( + datetime_range_y_m_d_overflow_year, + template = [TemplatePart::TimeFormat("%Y-%m-%d"),], + partition_key = "2023-12-31", + want = [( + TIME_COLUMN_NAME, + ColumnValue::Datetime { + begin: Utc.with_ymd_and_hms(2023, 12, 31, 0, 0, 0).unwrap(), + end: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(), + }, + )] + ); + + test_build_column_values!( + datetime_range_d_m_y, + template = [TemplatePart::TimeFormat("%d-%m-%Y"),], + partition_key = "01-09-2023", + want = [( + TIME_COLUMN_NAME, + ColumnValue::Datetime { + begin: Utc.with_ymd_and_hms(2023, 9, 1, 0, 0, 0).unwrap(), + end: Utc.with_ymd_and_hms(2023, 9, 2, 0, 0, 0).unwrap(), + }, + )] + ); + + test_build_column_values!( + bucket_part_fixture, + template = [ + TemplatePart::Bucket("a", 41), + TemplatePart::Bucket("b", 91), + TemplatePart::Bucket("c", 144) + ], + partition_key = "1|2|3", + want = [("a", bucket(1)), ("b", bucket(2)), ("c", bucket(3)),] + ); + + #[test] + #[should_panic] + fn test_build_column_values_bucket_part_out_of_range_panics() { + let template = [ + TemplatePart::Bucket("a", 42), + TemplatePart::Bucket("b", 42), + TemplatePart::Bucket("c", 42), + ] + .into_iter() + .collect::>(); + let template = 
test_table_partition_override(template); + + // normalise the values into a (str, ColumnValue) for the comparison + let input = String::from("1|1|43"); + let _ = build_column_values(&template, input.as_str()).collect::>(); + } + + #[test] + #[should_panic] + fn test_build_column_values_bucket_part_not_u32_panics() { + let template = [ + TemplatePart::Bucket("a", 42), + TemplatePart::Bucket("b", 42), + TemplatePart::Bucket("c", 42), + ] + .into_iter() + .collect::>(); + let template = test_table_partition_override(template); + + // normalise the values into a (str, ColumnValue) for the comparison + let input = String::from("1|1|bananas"); + let _ = build_column_values(&template, input.as_str()).collect::>(); + } + + test_build_column_values!( + datetime_not_compact_y_d, + template = [TemplatePart::TimeFormat("%Y-%d"),], + partition_key = "2023-01", + want = [] + ); + + test_build_column_values!( + datetime_not_compact_m, + template = [TemplatePart::TimeFormat("%m"),], + partition_key = "01", + want = [] + ); + + test_build_column_values!( + datetime_not_compact_d, + template = [TemplatePart::TimeFormat("%d"),], + partition_key = "01", + want = [] + ); + + test_build_column_values!( + datetime_range_unimplemented_y_m_d_h, + template = [TemplatePart::TimeFormat("%Y-%m-%dT%H"),], + partition_key = "2023-12-31T00", + want = [] + ); + + test_build_column_values!( + datetime_range_unimplemented_y_m_d_h_m, + template = [TemplatePart::TimeFormat("%Y-%m-%dT%H:%M"),], + partition_key = "2023-12-31T00:00", + want = [] + ); + + test_build_column_values!( + datetime_range_unimplemented_y_m_d_h_m_s, + template = [TemplatePart::TimeFormat("%Y-%m-%dT%H:%M:%S"),], + partition_key = "2023-12-31T00:00:00", + want = [] + ); + + test_build_column_values!( + empty_tag_only, + template = [TemplatePart::TagValue("a")], + partition_key = "!", + want = [] + ); + + #[test] + fn test_null_partition_key_char_str_equality() { + assert_eq!( + PARTITION_KEY_VALUE_NULL.to_string(), + 
PARTITION_KEY_VALUE_NULL_STR + ); + } + + #[test] + fn test_column_value_partial_eq() { + assert_eq!(identity("bananas"), "bananas"); + + assert_ne!(identity("bananas"), "bananas2"); + assert_ne!(identity("bananas2"), "bananas"); + + assert_ne!(prefix("bananas"), "bananas"); + assert_ne!(prefix("bananas"), "bananas2"); + assert_ne!(prefix("bananas2"), "bananas"); + } + + #[test] + fn test_column_value_is_prefix_match() { + let b = "bananas".to_string(); + assert!(identity("bananas").is_prefix_match_of(b)); + + assert!(identity("bananas").is_prefix_match_of("bananas")); + assert!(identity("bananas").is_prefix_match_of("bananas2")); + + assert!(prefix("bananas").is_prefix_match_of("bananas")); + assert!(prefix("bananas").is_prefix_match_of("bananas2")); + + assert!(!identity("bananas2").is_prefix_match_of("bananas")); + assert!(!prefix("bananas2").is_prefix_match_of("bananas")); + } + + /// This test asserts the default derived partitioning scheme with no + /// overrides. + /// + /// Changing this default during the lifetime of a cluster will cause the + /// implicit (not overridden) partition schemes to change, potentially + /// breaking the system invariant that a given primary keys maps to a + /// single partition. + /// + /// You shouldn't be changing this! 
+ #[test] + fn test_default_template_fixture() { + let ns = NamespacePartitionTemplateOverride::default(); + let table = TablePartitionTemplateOverride::try_new(None, &ns).unwrap(); + let got = table.parts().collect::>(); + assert_matches!(got.as_slice(), [TemplatePart::TimeFormat("%Y-%m-%d")]); + } + + #[test] + fn len_of_default_template_is_1() { + let ns = NamespacePartitionTemplateOverride::default(); + let t = TablePartitionTemplateOverride::try_new(None, &ns).unwrap(); + + assert_eq!(t.len(), 1); + } + + #[test] + fn no_custom_table_template_specified_gets_namespace_template() { + let namespace_template = + NamespacePartitionTemplateOverride::try_from(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat("year-%Y".into())), + }], + }) + .unwrap(); + let table_template = + TablePartitionTemplateOverride::try_new(None, &namespace_template).unwrap(); + + assert_eq!(table_template.len(), 1); + assert_eq!(table_template.0, namespace_template.0); + } + + #[test] + fn custom_table_template_specified_ignores_namespace_template() { + let custom_table_template = proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue("region".into())), + }], + }; + let namespace_template = + NamespacePartitionTemplateOverride::try_from(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat("year-%Y".into())), + }], + }) + .unwrap(); + let table_template = TablePartitionTemplateOverride::try_new( + Some(custom_table_template.clone()), + &namespace_template, + ) + .unwrap(); + + assert_eq!(table_template.len(), 1); + assert_eq!(table_template.0.unwrap().inner(), &custom_table_template); + } + + // The JSON representation of the partition template protobuf is stored in the database, so + // the encode/decode implementations need to be stable if we want to avoid having to + // migrate the values stored in the 
database. + + #[test] + fn proto_encode_json_stability() { + let custom_template = proto::PartitionTemplate { + parts: vec![ + proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue("region".into())), + }, + proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat("year-%Y".into())), + }, + proto::TemplatePart { + part: Some(proto::template_part::Part::Bucket(proto::Bucket { + tag_name: "bananas".into(), + num_buckets: 42, + })), + }, + ], + }; + let expected_json_str = "{\"parts\":[\ + {\"tagValue\":\"region\"},\ + {\"timeFormat\":\"year-%Y\"},\ + {\"bucket\":{\"tagName\":\"bananas\",\"numBuckets\":42}}\ + ]}"; + + let namespace = NamespacePartitionTemplateOverride::try_from(custom_template).unwrap(); + let mut buf = Default::default(); + let _ = >::encode_by_ref( + &namespace, &mut buf, + ); + + fn extract_sqlite_argument_text( + argument_value: &sqlx::sqlite::SqliteArgumentValue<'_>, + ) -> String { + match argument_value { + sqlx::sqlite::SqliteArgumentValue::Text(cow) => cow.to_string(), + other => panic!("Expected Text values, got: {other:?}"), + } + } + + let namespace_json_str: String = buf.iter().map(extract_sqlite_argument_text).collect(); + assert_eq!(namespace_json_str, expected_json_str); + + let table = TablePartitionTemplateOverride::try_new(None, &namespace).unwrap(); + let mut buf = Default::default(); + let _ = >::encode_by_ref( + &table, &mut buf, + ); + let table_json_str: String = buf.iter().map(extract_sqlite_argument_text).collect(); + assert_eq!(table_json_str, expected_json_str); + assert_eq!(table.len(), 3); + } + + #[test] + fn test_template_size_reporting() { + const BASE_SIZE: usize = std::mem::size_of::() + + std::mem::size_of::(); + + let first_string = "^"; + let template = TablePartitionTemplateOverride::try_new( + Some(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue(first_string.into())), + }], + }), + 
&NamespacePartitionTemplateOverride::default(), + ) + .expect("failed to create table partition template "); + + assert_eq!( + template.size(), + BASE_SIZE + std::mem::size_of::() + first_string.len() + ); + + let second_string = "region"; + let template = TablePartitionTemplateOverride::try_new( + Some(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue(second_string.into())), + }], + }), + &NamespacePartitionTemplateOverride::default(), + ) + .expect("failed to create table partition template "); + + assert_eq!( + template.size(), + BASE_SIZE + std::mem::size_of::() + second_string.len() + ); + + let time_string = "year-%Y"; + let template = TablePartitionTemplateOverride::try_new( + Some(proto::PartitionTemplate { + parts: vec![ + proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue(second_string.into())), + }, + proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat(time_string.into())), + }, + ], + }), + &NamespacePartitionTemplateOverride::default(), + ) + .expect("failed to create table partition template "); + assert_eq!( + template.size(), + BASE_SIZE + + std::mem::size_of::() + + second_string.len() + + std::mem::size_of::() + + time_string.len() + ); + + let template = TablePartitionTemplateOverride::try_new( + Some(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::Bucket(proto::Bucket { + tag_name: second_string.into(), + num_buckets: 42, + })), + }], + }), + &NamespacePartitionTemplateOverride::default(), + ) + .expect("failed to create table partition template"); + assert_eq!( + template.size(), + BASE_SIZE + + std::mem::size_of::() + + second_string.len() + + std::mem::size_of::() + ); + } +} diff --git a/data_types/src/sequence_number_set.rs b/data_types/src/sequence_number_set.rs new file mode 100644 index 0000000..ff8ca41 --- /dev/null +++ b/data_types/src/sequence_number_set.rs @@ -0,0 +1,325 @@ 
+//! A set of [`SequenceNumber`] instances. + +use std::collections::BTreeMap; + +use crate::SequenceNumber; + +/// A space-efficient encoded set of [`SequenceNumber`]. +#[derive(Debug, Default, Clone, PartialEq)] +pub struct SequenceNumberSet(croaring::Treemap); + +impl SequenceNumberSet { + /// Add the specified [`SequenceNumber`] to the set. + pub fn add(&mut self, n: SequenceNumber) { + self.0.add(n.get() as _); + } + + /// Remove the specified [`SequenceNumber`] to the set, if present. + /// + /// This is a no-op if `n` was not part of `self`. + pub fn remove(&mut self, n: SequenceNumber) { + self.0.remove(n.get() as _); + } + + /// Add all the [`SequenceNumber`] in `other` to `self`. + /// + /// The result of this operation is the set union of both input sets. + pub fn add_set(&mut self, other: &Self) { + self.0.or_inplace(&other.0) + } + + /// Remove all the [`SequenceNumber`] in `other` from `self`. + pub fn remove_set(&mut self, other: &Self) { + self.0.andnot_inplace(&other.0) + } + + /// Reduce the memory usage of this set (trading off immediate CPU time) by + /// efficiently re-encoding the set (using run-length encoding). + pub fn run_optimise(&mut self) { + self.0.run_optimize(); + } + + /// Return true if the specified [`SequenceNumber`] has been added to + /// `self`. + pub fn contains(&self, n: SequenceNumber) -> bool { + self.0.contains(n.get() as _) + } + + /// Returns the number of [`SequenceNumber`] in this set. + pub fn len(&self) -> u64 { + self.0.cardinality() + } + + /// Return `true` if there are no [`SequenceNumber`] in this set. + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Return an iterator of all [`SequenceNumber`] in this set. + pub fn iter(&self) -> impl Iterator + '_ { + self.0.iter().map(|v| SequenceNumber::new(v as _)) + } + + /// Initialise a [`SequenceNumberSet`] that is pre-allocated to contain up + /// to `n` elements without reallocating. 
+ pub fn with_capacity(n: u32) -> Self { + let mut map = BTreeMap::new(); + map.insert(0, croaring::Bitmap::with_container_capacity(n)); + Self(croaring::Treemap { map }) + } +} + +impl Extend for SequenceNumberSet { + fn extend>(&mut self, iter: T) { + self.0.extend(iter.into_iter().map(|v| v.get() as _)) + } +} + +impl Extend for SequenceNumberSet { + fn extend>(&mut self, iter: T) { + for new_set in iter { + self.add_set(&new_set); + } + } +} + +impl FromIterator for SequenceNumberSet { + fn from_iter>(iter: T) -> Self { + Self(iter.into_iter().map(|v| v.get() as _).collect()) + } +} + +/// Return the intersection of `self` and `other`. +pub fn intersect(a: &SequenceNumberSet, b: &SequenceNumberSet) -> SequenceNumberSet { + SequenceNumberSet(a.0.and(&b.0)) +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use proptest::{prelude::prop, proptest, strategy::Strategy}; + + use super::*; + + #[test] + fn test_set_operations() { + let mut a = SequenceNumberSet::default(); + let mut b = SequenceNumberSet::default(); + + // Add an element and check it is readable + a.add(SequenceNumber::new(1)); + assert!(a.contains(SequenceNumber::new(1))); + assert_eq!(a.len(), 1); + assert_eq!(a.iter().collect::>(), vec![SequenceNumber::new(1)]); + assert!(!a.contains(SequenceNumber::new(42))); + + // Merging an empty set into a should not change a + a.add_set(&b); + assert_eq!(a.len(), 1); + assert!(a.contains(SequenceNumber::new(1))); + + // Merging a non-empty set should add the new elements + b.add(SequenceNumber::new(2)); + a.add_set(&b); + assert_eq!(a.len(), 2); + assert!(a.contains(SequenceNumber::new(1))); + assert!(a.contains(SequenceNumber::new(2))); + + // Removing the set should return it to the pre-merged state. 
+ a.remove_set(&b); + assert_eq!(a.len(), 1); + assert!(a.contains(SequenceNumber::new(1))); + + // Removing a non-existant element should be a NOP + a.remove(SequenceNumber::new(42)); + assert_eq!(a.len(), 1); + + // Removing the last element should result in an empty set. + a.remove(SequenceNumber::new(1)); + assert_eq!(a.len(), 0); + } + + #[test] + fn test_extend() { + let mut a = SequenceNumberSet::default(); + a.add(SequenceNumber::new(42)); + + let extend_set = [SequenceNumber::new(4), SequenceNumber::new(2)]; + + assert!(a.contains(SequenceNumber::new(42))); + assert!(!a.contains(SequenceNumber::new(4))); + assert!(!a.contains(SequenceNumber::new(2))); + + a.extend(extend_set); + + assert!(a.contains(SequenceNumber::new(42))); + assert!(a.contains(SequenceNumber::new(4))); + assert!(a.contains(SequenceNumber::new(2))); + } + + #[test] + fn test_extend_multiple_sets() { + let mut a = SequenceNumberSet::default(); + a.add(SequenceNumber::new(7)); + + let b = [SequenceNumber::new(13), SequenceNumber::new(76)]; + let c = [SequenceNumber::new(42), SequenceNumber::new(64)]; + + assert!(a.contains(SequenceNumber::new(7))); + for &num in [b, c].iter().flatten() { + assert!(!a.contains(num)); + } + + a.extend([ + SequenceNumberSet::from_iter(b), + SequenceNumberSet::from_iter(c), + ]); + assert!(a.contains(SequenceNumber::new(7))); + for &num in [b, c].iter().flatten() { + assert!(a.contains(num)); + } + } + + #[test] + fn test_collect() { + let collect_set = [SequenceNumber::new(4), SequenceNumber::new(2)]; + + let a = collect_set.into_iter().collect::(); + + assert!(!a.contains(SequenceNumber::new(42))); + assert!(a.contains(SequenceNumber::new(4))); + assert!(a.contains(SequenceNumber::new(2))); + } + + #[test] + fn test_partial_eq() { + let mut a = SequenceNumberSet::default(); + let mut b = SequenceNumberSet::default(); + + assert_eq!(a, b); + + a.add(SequenceNumber::new(42)); + assert_ne!(a, b); + + b.add(SequenceNumber::new(42)); + assert_eq!(a, b); + + 
b.add(SequenceNumber::new(24)); + assert_ne!(a, b); + + a.add(SequenceNumber::new(24)); + assert_eq!(a, b); + } + + #[test] + fn test_intersect() { + let a = [0, u64::MAX, 40, 41, 42, 43, 44, 45] + .into_iter() + .map(SequenceNumber::new) + .collect::(); + + let b = [1, 5, u64::MAX, 42] + .into_iter() + .map(SequenceNumber::new) + .collect::(); + + let intersection = intersect(&a, &b); + let want = [u64::MAX, 42] + .into_iter() + .map(SequenceNumber::new) + .collect::(); + + assert_eq!(intersection, want); + } + + /// Yield vec's of [`SequenceNumber`] derived from u64 values. + /// + /// This matches how the ingester allocates [`SequenceNumber`] - from a u64 + /// source. + fn sequence_number_vec() -> impl Strategy> { + prop::collection::vec(0..u64::MAX, 0..1024) + .prop_map(|vec| vec.into_iter().map(SequenceNumber::new).collect()) + } + + // The following tests compare to an order-independent HashSet, as the + // SequenceNumber uses the PartialOrd impl of the inner u64 for ordering, + // resulting in incorrect output when compared to an ordered set of cast as + // u64. + // + // https://github.com/influxdata/influxdb_iox/issues/7260 + // + // These tests also cover, collect()-ing to a SequenceNumberSet, etc. + proptest! { + /// Perform a SequenceNumberSet intersection test comparing the results + /// to the known-good stdlib HashSet intersection implementation. 
+ #[test] + fn prop_set_intersection( + a in sequence_number_vec(), + b in sequence_number_vec() + ) { + let known_a = a.iter().cloned().collect::>(); + let known_b = b.iter().cloned().collect::>(); + let set_a = a.into_iter().collect::(); + let set_b = b.into_iter().collect::(); + + // The sets should be equal + assert_eq!(set_a.iter().collect::>(), known_a, "set a does not match"); + assert_eq!(set_b.iter().collect::>(), known_b, "set b does not match"); + + let known_intersection = known_a.intersection(&known_b).cloned().collect::>(); + let set_intersection = intersect(&set_a, &set_b).iter().collect::>(); + + // The set intersections should be equal. + assert_eq!(set_intersection, known_intersection); + } + + /// Perform a SequenceNumberSet remove_set test comparing the results to + /// the known-good stdlib HashSet difference implementation. + #[test] + fn prop_set_difference( + a in sequence_number_vec(), + b in sequence_number_vec() + ) { + let known_a = a.iter().cloned().collect::>(); + let known_b = b.iter().cloned().collect::>(); + let mut set_a = a.into_iter().collect::(); + let set_b = b.into_iter().collect::(); + + // The sets should be equal + assert_eq!(set_a.iter().collect::>(), known_a, "set a does not match"); + assert_eq!(set_b.iter().collect::>(), known_b, "set b does not match"); + + let known_a = known_a.difference(&known_b).cloned().collect::>(); + set_a.remove_set(&set_b); + let set_a = set_a.iter().collect::>(); + + // The set difference should be equal. + assert_eq!(set_a, known_a); + } + + /// Perform a SequenceNumberSet add_set test comparing the results to + /// the known-good stdlib HashSet or implementation. 
+ #[test] + fn prop_set_add( + a in sequence_number_vec(), + b in sequence_number_vec() + ) { + let known_a = a.iter().cloned().collect::>(); + let known_b = b.iter().cloned().collect::>(); + let mut set_a = a.into_iter().collect::(); + let set_b = b.into_iter().collect::(); + + // The sets should be equal + assert_eq!(set_a.iter().collect::>(), known_a, "set a does not match"); + assert_eq!(set_b.iter().collect::>(), known_b, "set b does not match"); + + let known_a = known_a.union(&known_b).cloned().collect::>(); + set_a.add_set(&set_b); + let set_a = set_a.iter().collect::>(); + + // The sets should be equal. + assert_eq!(set_a, known_a); + } + } +} diff --git a/data_types/src/service_limits.rs b/data_types/src/service_limits.rs new file mode 100644 index 0000000..7c00b6a --- /dev/null +++ b/data_types/src/service_limits.rs @@ -0,0 +1,311 @@ +//! Types protecting production by implementing limits on customer data. + +use generated_types::influxdata::iox::namespace::{ + v1 as namespace_proto, v1::update_namespace_service_protection_limit_request::LimitUpdate, +}; +use observability_deps::tracing::*; +use std::num::NonZeroUsize; +use thiserror::Error; + +/// Definitions that apply to both MaxColumnsPerTable and MaxTables. Note that the hardcoded +/// default value specified in the macro invocation must be greater than 0 and fit in an `i32`. +macro_rules! define_service_limit { + ($type_name:ident, $default_value:expr, $documentation:expr) => { + /// $documentation + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] + pub struct $type_name(NonZeroUsize); + + impl TryFrom for $type_name { + type Error = ServiceLimitError; + + fn try_from(value: usize) -> Result { + // Even though the value is stored as a `usize`, service limits are stored as `i32` + // in the database and transferred as i32 over protobuf. So try to convert to an + // `i32` (and throw away the result) so that we know about invalid values before + // trying to use them. 
+ if i32::try_from(value).is_err() { + return Err(ServiceLimitError::MustFitInI32); + } + + let nonzero_value = + NonZeroUsize::new(value).ok_or(ServiceLimitError::MustBeGreaterThanZero)?; + + Ok(Self(nonzero_value)) + } + } + + impl TryFrom for $type_name { + type Error = ServiceLimitError; + + fn try_from(value: u64) -> Result { + // Even though the value is stored as a `usize`, service limits are stored as `i32` + // in the database and transferred as i32 over protobuf. So try to convert to an + // `i32` (and throw away the result) so that we know about invalid values before + // trying to use them. + if i32::try_from(value).is_err() { + return Err(ServiceLimitError::MustFitInI32); + } + + let nonzero_value = usize::try_from(value) + .ok() + .and_then(NonZeroUsize::new) + .ok_or(ServiceLimitError::MustBeGreaterThanZero)?; + + Ok(Self(nonzero_value)) + } + } + + impl TryFrom for $type_name { + type Error = ServiceLimitError; + + fn try_from(value: i32) -> Result { + let nonzero_value = usize::try_from(value) + .ok() + .and_then(NonZeroUsize::new) + .ok_or(ServiceLimitError::MustBeGreaterThanZero)?; + + Ok(Self(nonzero_value)) + } + } + + #[allow(missing_docs)] + impl $type_name { + pub fn get(&self) -> usize { + self.0.get() + } + + /// For use by the database and some protobuf representations. It should not be + /// possible to construct an instance that contains a `NonZeroUsize` that won't fit in + /// an `i32`. + pub fn get_i32(&self) -> i32 { + self.0.get() as i32 + } + + /// Constant-time default for use in constructing test constants. + pub const fn const_default() -> Self { + // This is safe because the hardcoded value is not 0. 
+ let value = unsafe { NonZeroUsize::new_unchecked($default_value) }; + + Self(value) + } + } + + impl Default for $type_name { + fn default() -> Self { + Self::const_default() + } + } + + impl std::fmt::Display for $type_name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } + } + + // Tell sqlx this is an i32 in the database. + impl sqlx::Type for $type_name + where + i32: sqlx::Type, + DB: sqlx::Database, + { + fn type_info() -> DB::TypeInfo { + >::type_info() + } + } + + impl<'q, DB> sqlx::Encode<'q, DB> for $type_name + where + DB: sqlx::Database, + i32: sqlx::Encode<'q, DB>, + { + fn encode_by_ref( + &self, + buf: &mut >::ArgumentBuffer, + ) -> sqlx::encode::IsNull { + >::encode_by_ref(&self.get_i32(), buf) + } + } + + // The database stores i32s, so there's a chance of invalid values already being stored in + // there. When deserializing those values, rather than panicking or returning an error, log + // and use the default instead. + impl<'r, DB: ::sqlx::Database> ::sqlx::decode::Decode<'r, DB> for $type_name + where + i32: sqlx::Decode<'r, DB>, + { + fn decode( + value: >::ValueRef, + ) -> ::std::result::Result< + Self, + ::std::boxed::Box< + dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync, + >, + > { + let data = >::decode(value)?; + + let data = Self::try_from(data).unwrap_or_else(|_| { + error!("database contains invalid $type_name value {data}, using default value"); + Self::default() + }); + + Ok(data) + } + } + }; +} + +define_service_limit!(MaxTables, 500, "Max tables allowed in a namespace."); +define_service_limit!( + MaxColumnsPerTable, + 200, + "Max columns per table allowed in a namespace." +); + +/// Overrides for service protection limits. 
+#[derive(Debug, Copy, Clone)] +pub struct NamespaceServiceProtectionLimitsOverride { + /// The maximum number of tables that can exist in this namespace + pub max_tables: Option, + /// The maximum number of columns per table in this namespace + pub max_columns_per_table: Option, +} + +impl TryFrom + for NamespaceServiceProtectionLimitsOverride +{ + type Error = ServiceLimitError; + + fn try_from(value: namespace_proto::ServiceProtectionLimits) -> Result { + let namespace_proto::ServiceProtectionLimits { + max_tables, + max_columns_per_table, + } = value; + + Ok(Self { + max_tables: max_tables.map(MaxTables::try_from).transpose()?, + max_columns_per_table: max_columns_per_table + .map(MaxColumnsPerTable::try_from) + .transpose()?, + }) + } +} + +/// Updating one, but not both, of the limits is what the UpdateNamespaceServiceProtectionLimit +/// gRPC request supports, so match that encoding on the Rust side. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ServiceLimitUpdate { + /// Requesting an update to the maximum number of tables allowed in this namespace + MaxTables(MaxTables), + /// Requesting an update to the maximum number of columns allowed in each table in this + /// namespace + MaxColumnsPerTable(MaxColumnsPerTable), +} + +/// Errors converting from raw values to the service limits +#[derive(Error, Debug, Clone, Copy)] +pub enum ServiceLimitError { + /// A negative or 0 value was specified; those aren't allowed + #[error("service limit values must be greater than 0")] + MustBeGreaterThanZero, + + /// No value was provided so we can't update anything + #[error("a supported service limit value is required")] + NoValueSpecified, + + /// Limits are stored as `i32` in the database and transferred as i32 over protobuf, so even + /// though they are stored as `usize` in Rust, the `usize` value must be less than `i32::MAX`. 
+ #[error("service limit values must fit in a 32-bit signed integer (`i32`)")] + MustFitInI32, +} + +impl TryFrom> for ServiceLimitUpdate { + type Error = ServiceLimitError; + + fn try_from(limit_update: Option) -> Result { + match limit_update { + Some(LimitUpdate::MaxTables(n)) => { + Ok(ServiceLimitUpdate::MaxTables(MaxTables::try_from(n)?)) + } + Some(LimitUpdate::MaxColumnsPerTable(n)) => Ok(ServiceLimitUpdate::MaxColumnsPerTable( + MaxColumnsPerTable::try_from(n)?, + )), + None => Err(ServiceLimitError::NoValueSpecified), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn extract_sqlite_argument_i32(argument_value: &sqlx::sqlite::SqliteArgumentValue<'_>) -> i32 { + match argument_value { + sqlx::sqlite::SqliteArgumentValue::Int(i) => *i, + other => panic!("Expected Int values, got: {other:?}"), + } + } + + macro_rules! service_limit_test { + ($type_name:ident, $module_name: ident) => { + mod $module_name { + use super::*; + + fn success>(value: T, expected: usize) + where + >::Error: std::fmt::Debug, + { + assert_eq!(value.try_into().unwrap().get(), expected); + } + + #[test] + fn successful_conversions() { + success(1usize, 1); + success(1u64, 1); + success(1i32, 1); + success(i32::MAX, i32::MAX as usize); + } + + fn failure>(value: T, expected_error_message: &str) + where + >::Error: std::fmt::Debug + std::fmt::Display, + { + assert_eq!( + value.try_into().unwrap_err().to_string(), + expected_error_message + ); + } + + #[test] + fn failed_conversions() { + failure(0usize, "service limit values must be greater than 0"); + failure(0u64, "service limit values must be greater than 0"); + failure(0i32, "service limit values must be greater than 0"); + failure(-1i32, "service limit values must be greater than 0"); + failure( + i32::MAX as usize + 1, + "service limit values must fit in a 32-bit signed integer (`i32`)", + ); + failure( + i32::MAX as u64 + 1, + "service limit values must fit in a 32-bit signed integer (`i32`)", + ); + } + + #[test] + fn 
encode() { + let value = $type_name::try_from(10).unwrap(); + let mut buf = Default::default(); + let _ = <$type_name as sqlx::Encode<'_, sqlx::Sqlite>>::encode_by_ref( + &value, &mut buf, + ); + + let encoded: Vec<_> = buf.iter().map(extract_sqlite_argument_i32).collect(); + assert_eq!(encoded, &[value.get_i32()]); + } + } + }; + } + + service_limit_test!(MaxTables, max_tables); + service_limit_test!(MaxColumnsPerTable, max_columns_per_table); +} diff --git a/data_types/src/snapshot/hash.rs b/data_types/src/snapshot/hash.rs new file mode 100644 index 0000000..adf8c24 --- /dev/null +++ b/data_types/src/snapshot/hash.rs @@ -0,0 +1,219 @@ +//! A primitive hash table supporting linear probing + +use bytes::Bytes; +use generated_types::influxdata::iox::catalog_cache::v1 as generated; +use siphasher::sip::SipHasher24; + +use snafu::{ensure, Snafu}; + +/// Error for [`HashBuckets`] +#[derive(Debug, Snafu)] +#[allow(missing_docs, missing_copy_implementations)] +pub enum Error { + #[snafu(display("Bucket length not a power of two"))] + BucketsNotPower, + #[snafu(display("Unrecognized hash function"))] + UnrecognizedHash, +} + +/// Result for [`HashBuckets`] +pub type Result = std::result::Result; + +/// A primitive hash table supporting [linear probing] +/// +/// [linear probing](https://en.wikipedia.org/wiki/Linear_probing) +#[derive(Debug, Clone)] +pub struct HashBuckets { + /// The mask to yield index in `buckets` from a u64 hash + mask: usize, + /// A sequence of u32 encoding the value index + 1, or 0 if empty + buckets: Bytes, + /// The hash function to use + hash: SipHasher24, +} + +impl HashBuckets { + /// Performs a lookup of `value` + pub fn lookup(&self, value: &[u8]) -> HashProbe<'_> { + self.lookup_raw(self.hash.hash(value)) + } + + fn lookup_raw(&self, hash: u64) -> HashProbe<'_> { + let idx = (hash as usize) & self.mask; + HashProbe { + idx, + buckets: self, + mask: self.mask as _, + } + } +} + +impl TryFrom for HashBuckets { + type Error = Error; + + fn 
try_from(value: generated::HashBuckets) -> std::result::Result { + let buckets_len = value.buckets.len(); + ensure!(buckets_len.count_ones() == 1, BucketsNotPowerSnafu); + let mask = buckets_len.wrapping_sub(1) ^ 3; + match value.hash_function { + Some(generated::hash_buckets::HashFunction::SipHash24(s)) => Ok(Self { + mask, + buckets: value.buckets, + hash: SipHasher24::new_with_keys(s.key0, s.key1), + }), + _ => Err(Error::UnrecognizedHash), + } + } +} + +impl From for generated::HashBuckets { + fn from(value: HashBuckets) -> Self { + let (key0, key1) = value.hash.keys(); + Self { + buckets: value.buckets, + hash_function: Some(generated::hash_buckets::HashFunction::SipHash24( + generated::SipHash24 { key0, key1 }, + )), + } + } +} + +/// Yields the indices to probe for equality +#[derive(Debug)] +pub struct HashProbe<'a> { + buckets: &'a HashBuckets, + idx: usize, + mask: usize, +} + +impl<'a> Iterator for HashProbe<'a> { + type Item = usize; + + fn next(&mut self) -> Option { + let slice = self.buckets.buckets.get(self.idx..self.idx + 4)?; + let entry = u32::from_le_bytes(slice.try_into().unwrap()); + self.idx = (self.idx + 4) & self.mask; + + // Empty entries are encoded as 0 + Some(entry.checked_sub(1)? 
as usize) + } +} + +/// An encoder for [`HashBuckets`] +#[derive(Debug)] +pub struct HashBucketsEncoder { + mask: usize, + buckets: Vec, + hash: SipHasher24, + len: u32, + capacity: u32, +} + +impl HashBucketsEncoder { + /// Create a new [`HashBucketsEncoder`] + /// + /// # Panics + /// + /// Panics if capacity >= u32::MAX + pub fn new(capacity: usize) -> Self { + assert!(capacity < u32::MAX as usize); + + let buckets_len = (capacity * 2).next_power_of_two() * 4; + let mask = buckets_len.wrapping_sub(1) ^ 3; + Self { + mask, + len: 0, + capacity: capacity as u32, + buckets: vec![0; buckets_len], + // Note: this uses keys (0, 0) + hash: SipHasher24::new(), + } + } + + /// Append a new value + /// + /// # Panics + /// + /// Panics if this would exceed the capacity provided to new + pub fn push(&mut self, v: &[u8]) { + self.push_raw(self.hash.hash(v)); + } + + /// Append a new value by hash, returning the bucket index + fn push_raw(&mut self, hash: u64) -> usize { + assert_ne!(self.len, self.capacity); + self.len += 1; + let entry = self.len; + let mut idx = (hash as usize) & self.mask; + loop { + let s = &mut self.buckets[idx..idx + 4]; + let s: &mut [u8; 4] = s.try_into().unwrap(); + if s.iter().all(|x| *x == 0) { + *s = entry.to_le_bytes(); + return idx / 4; + } + idx = (idx + 4) & self.mask; + } + } + + /// Construct the output [`HashBuckets`] + pub fn finish(self) -> HashBuckets { + HashBuckets { + mask: self.mask, + hash: self.hash, + buckets: self.buckets.into(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_collision() { + let mut builder = HashBucketsEncoder::new(6); + + assert_eq!(builder.push_raw(14), 3); + assert_eq!(builder.push_raw(297), 10); + assert_eq!(builder.push_raw(43), 11); // Hashes to occupied bucket 10 + assert_eq!(builder.push_raw(60), 15); + assert_eq!(builder.push_raw(124), 0); // Hashes to occupied bucket 15 + assert_eq!(builder.push_raw(0), 1); // Hashes to occupied bucket 0 + + let buckets = 
builder.finish(); + + let l = buckets.lookup_raw(14).collect::>(); + assert_eq!(l, vec![0]); + + let l = buckets.lookup_raw(297).collect::>(); + assert_eq!(l, vec![1, 2]); + + let l = buckets.lookup_raw(43).collect::>(); + assert_eq!(l, vec![1, 2]); + + let l = buckets.lookup_raw(60).collect::>(); + assert_eq!(l, vec![3, 4, 5]); + + let l = buckets.lookup_raw(0).collect::>(); + assert_eq!(l, vec![4, 5]); + } + + #[test] + fn test_basic() { + let data = ["a", "", "bongos", "cupcakes", "bananas"]; + let mut builder = HashBucketsEncoder::new(data.len()); + for s in &data { + builder.push(s.as_bytes()); + } + let buckets = builder.finish(); + + let contains = |s: &str| -> bool { buckets.lookup(s.as_bytes()).any(|idx| data[idx] == s) }; + + assert!(contains("a")); + assert!(contains("")); + assert!(contains("bongos")); + assert!(contains("bananas")); + assert!(!contains("windows")); + } +} diff --git a/data_types/src/snapshot/list.rs b/data_types/src/snapshot/list.rs new file mode 100644 index 0000000..bd86b98 --- /dev/null +++ b/data_types/src/snapshot/list.rs @@ -0,0 +1,192 @@ +//! 
A list of [`Message`] supporting efficient skipping + +use bytes::Bytes; +use prost::Message; +use snafu::{ensure, Snafu}; +use std::marker::PhantomData; +use std::ops::Range; + +use generated_types::influxdata::iox::catalog_cache::v1 as generated; + +/// Error type for [`MessageList`] +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +pub enum Error { + #[snafu(context(false), display("PackedList decode error: {source}"))] + DecodeError { source: prost::DecodeError }, + + #[snafu(context(false), display("PackedList encode error: {source}"))] + EncodeError { source: prost::EncodeError }, + + #[snafu(display("Invalid MessageList offsets: {start}..{end}"))] + InvalidSlice { start: usize, end: usize }, + + #[snafu(display("MessageList slice {start}..{end} out of bounds 0..{bounds}"))] + SliceOutOfBounds { + start: usize, + end: usize, + bounds: usize, + }, +} + +/// Error type for [`MessageList`] +pub type Result = std::result::Result; + +/// A packed list of [`Message`] +/// +/// Normally protobuf encodes repeated fields by simply encoding the tag multiple times, +/// see [here](https://protobuf.dev/programming-guides/encoding/#optional). +/// +/// Unfortunately this means it is not possible to locate a value at a given index without +/// decoding all prior records. 
[`MessageList`] therefore provides a list encoding, inspired +/// by arrow, that provides this and is designed to be combined with [`prost`]'s support +/// for zero-copy decoding of [`Bytes`] +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct MessageList { + len: usize, + offsets: Bytes, + values: Bytes, + phantom: PhantomData, +} + +impl MessageList { + /// Encode `values` to a [`MessageList`] + pub fn encode(values: &[T]) -> Result { + let cap = (values.len() + 1) * 4; + let mut offsets: Vec = Vec::with_capacity(cap); + offsets.extend_from_slice(&0_u32.to_le_bytes()); + + let mut cap = 0; + for x in values { + cap += x.encoded_len(); + let offset = u32::try_from(cap).unwrap(); + offsets.extend_from_slice(&offset.to_le_bytes()); + } + + let mut data = Vec::with_capacity(cap); + values.iter().try_for_each(|x| x.encode(&mut data))?; + + Ok(Self { + len: values.len(), + offsets: offsets.into(), + values: data.into(), + phantom: Default::default(), + }) + } + + /// Returns true if this list is empty + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns the number of elements in this list + pub fn len(&self) -> usize { + self.len + } + + /// Returns the element at index `idx` + pub fn get(&self, idx: usize) -> Result { + let offset_start = idx * 4; + let offset_slice = &self.offsets[offset_start..offset_start + 8]; + let start = u32::from_le_bytes(offset_slice[0..4].try_into().unwrap()) as usize; + let end = u32::from_le_bytes(offset_slice[4..8].try_into().unwrap()) as usize; + + let bounds = self.values.len(); + ensure!(end >= start, InvalidSliceSnafu { start, end }); + ensure!(end <= bounds, SliceOutOfBoundsSnafu { start, end, bounds }); + + // We slice `Bytes` to preserve zero-copy + let data = self.values.slice(start..end); + Ok(T::decode(data)?) 
+ } +} + +impl From for MessageList { + fn from(proto: generated::MessageList) -> Self { + let len = (proto.offsets.len() / 4).saturating_sub(1); + Self { + len, + offsets: proto.offsets, + values: proto.values, + phantom: Default::default(), + } + } +} + +impl From> for generated::MessageList { + fn from(value: MessageList) -> Self { + Self { + offsets: value.offsets, + values: value.values, + } + } +} + +impl IntoIterator for MessageList { + type Item = Result; + type IntoIter = MessageListIter; + + fn into_iter(self) -> Self::IntoIter { + MessageListIter { + iter: (0..self.len), + list: self, + } + } +} + +/// [`Iterator`] for [`MessageList`] +#[derive(Debug)] +pub struct MessageListIter { + iter: Range, + list: MessageList, +} + +impl Iterator for MessageListIter { + type Item = Result; + + fn next(&mut self) -> Option { + Some(self.list.get(self.iter.next()?)) + } + + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_simple() { + let strings = ["", "test", "foo", "abc", "", "skd"]; + let strings: Vec<_> = strings.into_iter().map(ToString::to_string).collect(); + + let encoded = MessageList::encode(&strings).unwrap(); + + assert_eq!(encoded.get(5).unwrap().as_str(), "skd"); + assert_eq!(encoded.get(2).unwrap().as_str(), "foo"); + assert_eq!(encoded.get(0).unwrap().as_str(), ""); + + let decoded: Vec<_> = encoded.clone().into_iter().map(Result::unwrap).collect(); + assert_eq!(strings, decoded); + + let proto = generated::MessageList::from(encoded.clone()); + let back = MessageList::::from(proto.clone()); + assert_eq!(encoded, back); + + // Invalid decode should return error not panic + let invalid = MessageList::::from(proto); + invalid.get(2).unwrap_err(); + + let strings: Vec = vec![]; + let encoded = MessageList::encode(&strings).unwrap(); + assert_eq!(encoded.len(), 0); + assert!(encoded.is_empty()); + + let proto = generated::MessageList::default(); + let encoded = 
MessageList::::from(proto); + assert_eq!(encoded.len(), 0); + assert!(encoded.is_empty()); + } +} diff --git a/data_types/src/snapshot/mask.rs b/data_types/src/snapshot/mask.rs new file mode 100644 index 0000000..ae9dc3b --- /dev/null +++ b/data_types/src/snapshot/mask.rs @@ -0,0 +1,71 @@ +//! A packed bitmask + +use arrow_buffer::bit_iterator::BitIndexIterator; +use arrow_buffer::bit_util::{ceil, set_bit}; +use bytes::Bytes; +use generated_types::influxdata::iox::catalog_cache::v1 as generated; + +/// A packed bitmask +#[derive(Debug, Clone)] +pub struct BitMask { + mask: Bytes, + len: usize, +} + +impl BitMask { + /// Returns an iterator of the set indices in this mask + pub fn set_indices(&self) -> BitIndexIterator<'_> { + BitIndexIterator::new(&self.mask, 0, self.len) + } +} + +impl From for BitMask { + fn from(value: generated::BitMask) -> Self { + Self { + mask: value.mask, + len: value.len as _, + } + } +} + +impl From for generated::BitMask { + fn from(value: BitMask) -> Self { + Self { + mask: value.mask, + len: value.len as _, + } + } +} + +/// A builder for [`BitMask`] +#[derive(Debug)] +pub struct BitMaskBuilder { + values: Vec, + len: usize, +} + +impl BitMaskBuilder { + /// Create a new bitmask able to store `len` boolean values + #[inline] + pub fn new(len: usize) -> Self { + Self { + values: vec![0; ceil(len, 8)], + len, + } + } + + /// Set the bit at index `idx` + #[inline] + pub fn set_bit(&mut self, idx: usize) { + set_bit(&mut self.values, idx) + } + + /// Return the built [`BitMask`] + #[inline] + pub fn finish(self) -> BitMask { + BitMask { + mask: self.values.into(), + len: self.len, + } + } +} diff --git a/data_types/src/snapshot/mod.rs b/data_types/src/snapshot/mod.rs new file mode 100644 index 0000000..7be5a93 --- /dev/null +++ b/data_types/src/snapshot/mod.rs @@ -0,0 +1,11 @@ +//! Definitions of catalog snapshots +//! +//! Snapshots are read-optimised, that is they are designed to be inexpensive to +//! 
decode, making extensive use of zero-copy [`Bytes`](bytes::Bytes) in place of +//! allocating structures such as `String` and `Vec` + +pub mod hash; +pub mod list; +pub mod mask; +pub mod partition; +pub mod table; diff --git a/data_types/src/snapshot/partition.rs b/data_types/src/snapshot/partition.rs new file mode 100644 index 0000000..d1838e5 --- /dev/null +++ b/data_types/src/snapshot/partition.rs @@ -0,0 +1,246 @@ +//! Snapshot definition for partitions + +use crate::snapshot::list::MessageList; +use crate::snapshot::mask::{BitMask, BitMaskBuilder}; +use crate::{ + ColumnId, ColumnSet, CompactionLevelProtoError, NamespaceId, ObjectStoreId, ParquetFile, + ParquetFileId, Partition, PartitionHashId, PartitionHashIdError, PartitionId, + SkippedCompaction, SortKeyIds, TableId, Timestamp, +}; +use bytes::Bytes; +use generated_types::influxdata::iox::{ + catalog_cache::v1 as proto, skipped_compaction::v1 as skipped_compaction_proto, +}; +use snafu::{OptionExt, ResultExt, Snafu}; + +/// Error for [`PartitionSnapshot`] +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +pub enum Error { + #[snafu(display("Error decoding PartitionFile: {source}"))] + FileDecode { + source: crate::snapshot::list::Error, + }, + + #[snafu(display("Error encoding ParquetFile: {source}"))] + FileEncode { + source: crate::snapshot::list::Error, + }, + + #[snafu(display("Missing required field {field}"))] + RequiredField { field: &'static str }, + + #[snafu(context(false))] + CompactionLevel { source: CompactionLevelProtoError }, + + #[snafu(context(false))] + PartitionHashId { source: PartitionHashIdError }, + + #[snafu(display("Invalid partition key: {source}"))] + PartitionKey { source: std::str::Utf8Error }, +} + +/// Result for [`PartitionSnapshot`] +pub type Result = std::result::Result; + +/// A snapshot of a partition +#[derive(Debug, Clone)] +pub struct PartitionSnapshot { + /// The [`NamespaceId`] + namespace_id: NamespaceId, + /// The [`TableId`] + table_id: TableId, + /// The 
[`PartitionId`] + partition_id: PartitionId, + /// The [`PartitionHashId`] + partition_hash_id: Option, + /// The generation of this snapshot + generation: u64, + /// The partition key + key: Bytes, + /// The files + files: MessageList, + /// The columns for this partition + columns: ColumnSet, + /// The sort key ids + sort_key: SortKeyIds, + /// The time of a new file + new_file_at: Option, + /// Skipped compaction. + skipped_compaction: Option, +} + +impl PartitionSnapshot { + /// Create a new [`PartitionSnapshot`] from the provided state + pub fn encode( + namespace_id: NamespaceId, + partition: Partition, + files: Vec, + skipped_compaction: Option, + generation: u64, + ) -> Result { + // Iterate in reverse order as schema additions are normally additive and + // so the later files will typically have more columns + let columns = files.iter().rev().fold(ColumnSet::empty(), |mut acc, v| { + acc.union(&v.column_set); + acc + }); + + let files = files + .into_iter() + .map(|file| { + let mut mask = BitMaskBuilder::new(columns.len()); + for (idx, _) in columns.intersect(&file.column_set) { + mask.set_bit(idx); + } + + proto::PartitionFile { + id: file.id.get(), + object_store_uuid: Some(file.object_store_id.get_uuid().into()), + min_time: file.min_time.0, + max_time: file.max_time.0, + file_size_bytes: file.file_size_bytes, + row_count: file.row_count, + compaction_level: file.compaction_level as _, + created_at: file.created_at.0, + max_l0_created_at: file.max_l0_created_at.0, + column_mask: Some(mask.finish().into()), + } + }) + .collect::>(); + + Ok(Self { + generation, + columns, + namespace_id, + partition_id: partition.id, + partition_hash_id: partition.hash_id().cloned(), + key: partition.partition_key.as_bytes().to_vec().into(), + files: MessageList::encode(&files).context(FileEncodeSnafu)?, + sort_key: partition.sort_key_ids().cloned().unwrap_or_default(), + table_id: partition.table_id, + new_file_at: partition.new_file_at, + skipped_compaction: 
skipped_compaction.map(|sc| sc.into()), + }) + } + + /// Create a new [`PartitionSnapshot`] from a `proto` and generation + pub fn decode(proto: proto::Partition, generation: u64) -> Self { + let table_id = TableId::new(proto.table_id); + let partition_hash_id = proto + .partition_hash_id + .then(|| PartitionHashId::from_raw(table_id, proto.key.as_ref())); + + Self { + generation, + table_id, + partition_hash_id, + key: proto.key, + files: MessageList::from(proto.files.unwrap_or_default()), + namespace_id: NamespaceId::new(proto.namespace_id), + partition_id: PartitionId::new(proto.partition_id), + columns: ColumnSet::new(proto.column_ids.into_iter().map(ColumnId::new)), + sort_key: SortKeyIds::new(proto.sort_key_ids.into_iter().map(ColumnId::new)), + new_file_at: proto.new_file_at.map(Timestamp::new), + skipped_compaction: proto.skipped_compaction, + } + } + + /// Returns the generation of this snapshot + pub fn generation(&self) -> u64 { + self.generation + } + + /// Returns the [`PartitionId`] + pub fn partition_id(&self) -> PartitionId { + self.partition_id + } + + /// Returns the [`PartitionHashId`] if any + pub fn partition_hash_id(&self) -> Option<&PartitionHashId> { + self.partition_hash_id.as_ref() + } + + /// Returns the file at index `idx` + pub fn file(&self, idx: usize) -> Result { + let file = self.files.get(idx).context(FileDecodeSnafu)?; + + let uuid = file.object_store_uuid.context(RequiredFieldSnafu { + field: "object_store_uuid", + })?; + + let column_set = match file.column_mask { + Some(mask) => { + let mask = BitMask::from(mask); + ColumnSet::new(mask.set_indices().map(|idx| self.columns[idx])) + } + None => self.columns.clone(), + }; + + Ok(ParquetFile { + id: ParquetFileId(file.id), + namespace_id: self.namespace_id, + table_id: self.table_id, + partition_id: self.partition_id, + partition_hash_id: self.partition_hash_id.clone(), + object_store_id: ObjectStoreId::from_uuid(uuid.into()), + min_time: Timestamp(file.min_time), + max_time: 
Timestamp(file.max_time), + to_delete: None, + file_size_bytes: file.file_size_bytes, + row_count: file.row_count, + compaction_level: file.compaction_level.try_into()?, + created_at: Timestamp(file.created_at), + column_set, + max_l0_created_at: Timestamp(file.max_l0_created_at), + }) + } + + /// Returns an iterator over the files in this snapshot + pub fn files(&self) -> impl Iterator> + '_ { + (0..self.files.len()).map(|idx| self.file(idx)) + } + + /// Returns the [`Partition`] for this snapshot + pub fn partition(&self) -> Result { + let key = std::str::from_utf8(&self.key).context(PartitionKeySnafu)?; + Ok(Partition::new_catalog_only( + self.partition_id, + self.partition_hash_id.clone(), + self.table_id, + key.into(), + self.sort_key.clone(), + self.new_file_at, + )) + } + + /// Returns the columns IDs + pub fn column_ids(&self) -> &ColumnSet { + &self.columns + } + + /// Return skipped compaction for this partition, if any. + pub fn skipped_compaction(&self) -> Option { + self.skipped_compaction + .as_ref() + .cloned() + .map(|sc| sc.into()) + } +} + +impl From for proto::Partition { + fn from(value: PartitionSnapshot) -> Self { + Self { + key: value.key, + files: Some(value.files.into()), + namespace_id: value.namespace_id.get(), + table_id: value.table_id.get(), + partition_id: value.partition_id.get(), + partition_hash_id: value.partition_hash_id.is_some(), + column_ids: value.columns.iter().map(|x| x.get()).collect(), + sort_key_ids: value.sort_key.iter().map(|x| x.get()).collect(), + new_file_at: value.new_file_at.map(|x| x.get()), + skipped_compaction: value.skipped_compaction, + } + } +} diff --git a/data_types/src/snapshot/table.rs b/data_types/src/snapshot/table.rs new file mode 100644 index 0000000..08c235d --- /dev/null +++ b/data_types/src/snapshot/table.rs @@ -0,0 +1,197 @@ +//! 
Snapshot definition for tables +use crate::snapshot::list::MessageList; +use crate::{ + Column, ColumnId, ColumnTypeProtoError, NamespaceId, Partition, PartitionId, Table, TableId, +}; +use bytes::Bytes; +use generated_types::influxdata::iox::catalog_cache::v1 as proto; +use generated_types::influxdata::iox::column_type::v1::ColumnType; +use generated_types::influxdata::iox::partition_template::v1::PartitionTemplate; +use snafu::{ResultExt, Snafu}; + +/// Error for [`TableSnapshot`] +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +pub enum Error { + #[snafu(display("Error decoding TablePartition: {source}"))] + PartitionDecode { + source: crate::snapshot::list::Error, + }, + + #[snafu(display("Error encoding TablePartition: {source}"))] + PartitionEncode { + source: crate::snapshot::list::Error, + }, + + #[snafu(display("Error decoding TableColumn: {source}"))] + ColumnDecode { + source: crate::snapshot::list::Error, + }, + + #[snafu(display("Error encoding TableColumn: {source}"))] + ColumnEncode { + source: crate::snapshot::list::Error, + }, + + #[snafu(display("Invalid column name: {source}"))] + ColumnName { source: std::str::Utf8Error }, + + #[snafu(display("Invalid table name: {source}"))] + TableName { source: std::str::Utf8Error }, + + #[snafu(display("Invalid partition template: {source}"))] + PartitionTemplate { + source: crate::partition_template::ValidationError, + }, + + #[snafu(context(false))] + ColumnType { source: ColumnTypeProtoError }, +} + +/// Result for [`TableSnapshot`] +pub type Result = std::result::Result; + +/// A snapshot of a table +#[derive(Debug, Clone)] +pub struct TableSnapshot { + table_id: TableId, + namespace_id: NamespaceId, + table_name: Bytes, + partitions: MessageList, + columns: MessageList, + partition_template: Option, + generation: u64, +} + +impl TableSnapshot { + /// Create a new [`TableSnapshot`] from the provided state + pub fn encode( + table: Table, + partitions: Vec, + columns: Vec, + generation: u64, + ) -> 
Result { + let columns: Vec<_> = columns + .into_iter() + .map(|c| proto::TableColumn { + id: c.id.get(), + name: c.name.into(), + column_type: ColumnType::from(c.column_type).into(), + }) + .collect(); + + let partitions: Vec<_> = partitions + .into_iter() + .map(|p| proto::TablePartition { + id: p.id.get(), + key: p.partition_key.as_bytes().to_vec().into(), + }) + .collect(); + + Ok(Self { + table_id: table.id, + namespace_id: table.namespace_id, + table_name: table.name.into(), + partitions: MessageList::encode(&partitions).context(PartitionEncodeSnafu)?, + columns: MessageList::encode(&columns).context(ColumnEncodeSnafu)?, + partition_template: table.partition_template.as_proto().cloned(), + generation, + }) + } + + /// Create a new [`TableSnapshot`] from a `proto` and generation + pub fn decode(proto: proto::Table, generation: u64) -> Self { + Self { + generation, + table_id: TableId::new(proto.table_id), + namespace_id: NamespaceId::new(proto.namespace_id), + table_name: proto.table_name, + partitions: MessageList::from(proto.partitions.unwrap_or_default()), + columns: MessageList::from(proto.columns.unwrap_or_default()), + partition_template: proto.partition_template, + } + } + + /// Returns the [`Table`] for this snapshot + pub fn table(&self) -> Result
{ + let name = std::str::from_utf8(&self.table_name).context(TableNameSnafu)?; + let template = self + .partition_template + .clone() + .try_into() + .context(PartitionTemplateSnafu)?; + + Ok(Table { + id: self.table_id, + namespace_id: self.namespace_id, + name: name.into(), + partition_template: template, + }) + } + + /// Returns the column by index + pub fn column(&self, idx: usize) -> Result { + let column = self.columns.get(idx).context(ColumnDecodeSnafu)?; + let name = std::str::from_utf8(&column.name).context(ColumnNameSnafu)?; + + Ok(Column { + id: ColumnId::new(column.id), + table_id: self.table_id, + name: name.into(), + column_type: (column.column_type as i16).try_into()?, + }) + } + + /// Returns an iterator of the columns in this table + pub fn columns(&self) -> impl Iterator> + '_ { + (0..self.columns.len()).map(|idx| self.column(idx)) + } + + /// Returns an iterator of the [`PartitionId`] in this table + pub fn partitions(&self) -> impl Iterator> + '_ { + (0..self.partitions.len()).map(|idx| { + let p = self.partitions.get(idx).context(PartitionDecodeSnafu)?; + Ok(TableSnapshotPartition { + id: PartitionId::new(p.id), + key: p.key, + }) + }) + } + + /// Returns the generation of this snapshot + pub fn generation(&self) -> u64 { + self.generation + } +} + +/// Partition information stored within [`TableSnapshot`] +#[derive(Debug)] +pub struct TableSnapshotPartition { + id: PartitionId, + key: Bytes, +} + +impl TableSnapshotPartition { + /// Returns the [`PartitionId`] for this partition + pub fn id(&self) -> PartitionId { + self.id + } + + /// Returns the partition key for this partition + pub fn key(&self) -> &[u8] { + &self.key + } +} + +impl From for proto::Table { + fn from(value: TableSnapshot) -> Self { + Self { + partitions: Some(value.partitions.into()), + columns: Some(value.columns.into()), + partition_template: value.partition_template, + namespace_id: value.namespace_id.get(), + table_id: value.table_id.get(), + table_name: 
value.table_name, + } + } +} diff --git a/datafusion_util/Cargo.toml b/datafusion_util/Cargo.toml new file mode 100644 index 0000000..1f5f554 --- /dev/null +++ b/datafusion_util/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "datafusion_util" +description = "Datafusion utilities" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[lints] +workspace = true + +[dependencies] +async-trait = "0.1" +datafusion = { workspace = true } +futures = "0.3" +object_store = { workspace = true } +observability_deps = { path = "../observability_deps" } +pin-project = "1.1" +schema = { path = "../schema" } +tokio = { version = "1.35", features = ["parking_lot", "sync"] } +tokio-stream = "0.1" +url = "2.5" +workspace-hack = { version = "0.1", path = "../workspace-hack" } diff --git a/datafusion_util/src/config.rs b/datafusion_util/src/config.rs new file mode 100644 index 0000000..ed41b19 --- /dev/null +++ b/datafusion_util/src/config.rs @@ -0,0 +1,50 @@ +use std::{fmt::Display, sync::Arc}; + +use datafusion::{ + config::ConfigOptions, execution::runtime_env::RuntimeEnv, prelude::SessionConfig, +}; +use object_store::ObjectStore; +use schema::TIME_DATA_TIMEZONE; +use url::Url; + +// The default catalog name - this impacts what SQL queries use if not specified +pub const DEFAULT_CATALOG: &str = "public"; +// The default schema name - this impacts what SQL queries use if not specified +pub const DEFAULT_SCHEMA: &str = "iox"; + +/// The maximum number of rows that DataFusion should create in each RecordBatch +pub const BATCH_SIZE: usize = 8 * 1024; + +/// Return a SessionConfig object configured for IOx +pub fn iox_session_config() -> SessionConfig { + // Enable parquet predicate pushdown optimization + let mut options = ConfigOptions::new(); + options.execution.parquet.pushdown_filters = true; + options.execution.parquet.reorder_filters = true; + options.execution.time_zone = TIME_DATA_TIMEZONE().map(|s| s.to_string()); + 
options.optimizer.repartition_sorts = true; + + SessionConfig::from(options) + .with_batch_size(BATCH_SIZE) + .with_create_default_catalog_and_schema(true) + .with_information_schema(true) + .with_default_catalog_and_schema(DEFAULT_CATALOG, DEFAULT_SCHEMA) + // Tell the datafusion optimizer to avoid repartitioning sorted inputs + .with_prefer_existing_sort(true) + // Avoid repartitioning file scans as it destroys existing sort orders + // see https://github.com/influxdata/influxdb_iox/issues/9450 + // see https://github.com/apache/arrow-datafusion/issues/8451 + .with_repartition_file_scans(false) +} + +/// Register the "IOx" object store provider for URLs of the form "iox://{id} +/// +/// Return the previous registered store, if any +pub fn register_iox_object_store( + runtime: impl AsRef, + id: D, + object_store: Arc, +) -> Option> { + let url = Url::parse(&format!("iox://{id}")).unwrap(); + runtime.as_ref().register_object_store(&url, object_store) +} diff --git a/datafusion_util/src/lib.rs b/datafusion_util/src/lib.rs new file mode 100644 index 0000000..6323f06 --- /dev/null +++ b/datafusion_util/src/lib.rs @@ -0,0 +1,519 @@ +#![deny( + clippy::future_not_send, + clippy::todo, + clippy::dbg_macro, + clippy::clone_on_ref_ptr, + rustdoc::broken_intra_doc_links, + rustdoc::bare_urls, + rust_2018_idioms, + unused_crate_dependencies +)] +#![allow(clippy::clone_on_ref_ptr)] + +//! This module contains various DataFusion utility functions. +//! +//! Almost everything for manipulating DataFusion `Expr`s IOx should be in DataFusion already +//! (or if not it should be upstreamed). +//! +//! For example, check out +//! [datafusion_optimizer::utils](https://docs.rs/datafusion-optimizer/13.0.0/datafusion_optimizer/utils/index.html) +//! for expression manipulation functions. + +use datafusion::execution::memory_pool::{MemoryPool, UnboundedMemoryPool}; +use std::collections::HashSet; +// Workaround for "unused crate" lint false positives. 
+use workspace_hack as _; + +pub mod config; +pub mod sender; +pub mod watch; + +use std::sync::Arc; +use std::task::{Context, Poll}; + +use datafusion::arrow::array::BooleanArray; +use datafusion::arrow::compute::filter_record_batch; +use datafusion::arrow::datatypes::{DataType, Fields}; +use datafusion::common::stats::Precision; +use datafusion::common::{DataFusionError, ToDFSchema}; +use datafusion::execution::context::TaskContext; +use datafusion::logical_expr::expr::Sort; +use datafusion::logical_expr::utils::inspect_expr_pre; +use datafusion::physical_expr::execution_props::ExecutionProps; +use datafusion::physical_expr::{create_physical_expr, PhysicalExpr}; +use datafusion::physical_optimizer::pruning::PruningPredicate; +use datafusion::physical_plan::{collect, EmptyRecordBatchStream, ExecutionPlan}; +use datafusion::prelude::{lit, Column, Expr, SessionContext}; +use datafusion::{ + arrow::{ + datatypes::{Schema, SchemaRef}, + record_batch::RecordBatch, + }, + physical_plan::{RecordBatchStream, SendableRecordBatchStream}, + scalar::ScalarValue, +}; +use futures::{Stream, StreamExt}; +use schema::TIME_DATA_TIMEZONE; +use tokio::sync::mpsc::{Receiver, UnboundedReceiver}; +use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream}; +use watch::WatchedTask; + +/// Traits to help creating DataFusion [`Expr`]s +pub trait AsExpr { + /// Creates a DataFusion expr + fn as_expr(&self) -> Expr; + + /// creates a DataFusion SortExpr + fn as_sort_expr(&self) -> Expr { + Expr::Sort(Sort { + expr: Box::new(self.as_expr()), + asc: true, // Sort ASCENDING + nulls_first: true, + }) + } +} + +impl AsExpr for Arc { + fn as_expr(&self) -> Expr { + self.as_ref().as_expr() + } +} + +impl AsExpr for str { + fn as_expr(&self) -> Expr { + // note using `col()` will parse identifiers and try to + // split them on `.`. 
+ // + // So it would treat 'foo.bar' as table 'foo', column 'bar' + // + // This is not correct for influxrpc, so instead treat it + // like the column "foo.bar" + Expr::Column(Column { + relation: None, + name: self.into(), + }) + } +} + +impl AsExpr for Expr { + fn as_expr(&self) -> Expr { + self.clone() + } +} + +/// Creates an `Expr` that represents a Dictionary encoded string (e.g +/// the type of constant that a tag would be compared to) +pub fn lit_dict(value: &str) -> Expr { + // expr has been type coerced + lit(ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::new_utf8(value)), + )) +} + +/// Creates expression like: +/// start <= time && time < end +pub fn make_range_expr(start: i64, end: i64, time: impl AsRef) -> Expr { + // We need to cast the start and end values to timestamps + // the equivalent of: + let ts_start = timestamptz_nano(start); + let ts_end = timestamptz_nano(end); + + let time_col = time.as_ref().as_expr(); + let ts_low = lit(ts_start).lt_eq(time_col.clone()); + let ts_high = time_col.lt(lit(ts_end)); + + ts_low.and(ts_high) +} + +/// Ensures all columns referred to in `filters` are in the `projection`, if +/// any, adding them if necessary. +pub fn extend_projection_for_filters( + schema: &Schema, + filters: &[Expr], + projection: Option<&Vec>, +) -> Result>, DataFusionError> { + let Some(mut projection) = projection.cloned() else { + return Ok(None); + }; + + let mut seen_cols: HashSet = projection.iter().cloned().collect(); + for filter in filters { + inspect_expr_pre(filter, |expr| { + if let Expr::Column(c) = expr { + let idx = schema.index_of(&c.name)?; + // if haven't seen this column before, add it to the list + if seen_cols.insert(idx) { + projection.push(idx); + } + } + Ok(()) as Result<(), DataFusionError> + })?; + } + Ok(Some(projection)) +} + +// TODO port this upstream to datafusion (maybe as From
{ + let tables = repos.tables(); + + // Note the export format doesn't currently have any table level information + let table_name = iox_metadata.table_name.as_ref(); + + if let Some(table) = tables + .get_by_namespace_and_name(namespace.id, table_name) + .await? + { + return Ok(table); + } + + // use exported table + if let Some(table) = self.exported_contents.table(namespace.id.get(), table_name) { + return Ok(tables + .create( + &table.name, + table.partition_template.try_into()?, + NamespaceId::new(table.namespace_id), + ) + .await?); + } + + // need to make a new table, create the default partitioning scheme... + let partition_template = PARTITION_BY_DAY_PROTO.as_ref().clone(); + let namespace_template = NamespacePartitionTemplateOverride::try_from(partition_template)?; + let custom_table_template = None; + let partition_template = + TablePartitionTemplateOverride::try_new(custom_table_template, &namespace_template)?; + let table = tables + .create(table_name, partition_template, namespace.id) + .await?; + Ok(table) + } + + /// Create the catalog [`Partition`] into which the specified parquet + /// file shoudl be inserted. + /// + /// The sort_key and sort_key_ids of the partition should be empty when it is first created + /// because there are no columns in any parquet files to use for sorting yet. + /// The sort_key and sort_key_ids will be updated after the parquet files are created. + async fn create_partition( + &self, + repos: &mut dyn RepoCollection, + table: &Table, + partition_key: PartitionKey, + ) -> Result { + let partition = repos + .partitions() + .create_or_get(partition_key, table.id) + .await?; + + Ok(partition) + } + + /// Update sort keys of the partition + /// + /// file should be inserted. 
+ /// + /// First attempts to use any available metadata from the + /// catalog export, and falls back to what is in the iox + /// metadata stored in the parquet file, if needed + async fn update_partition( + &self, + partition: &mut Partition, + repos: &mut dyn RepoCollection, + table: &Table, + iox_metadata: &IoxMetadata, + ) -> Result { + let partition_key = iox_metadata.partition_key.clone(); + + // Note we use the table_id embedded in the file's metadata + // from the source catalog to match the exported catalog (which + // is different than the new table we just created in the + // target catalog); + let proto_partition = self + .exported_contents + .partition_metadata(iox_metadata.table_id.get(), partition_key.inner()); + + let new_sort_key_ids = if let Some(proto_partition) = proto_partition.as_ref() { + // Use the sort key from the source catalog + debug!(sort_key_ids=?proto_partition.sort_key_ids, "Using sort key from catalog export"); + let new_sort_key_ids = match &proto_partition.sort_key_ids { + Some(sort_key_ids) => sort_key_ids.array_sort_key_ids.clone(), + None => vec![], + }; + + SortKeyIds::from(new_sort_key_ids) + } else { + warn!("Could not find sort key in catalog metadata export, falling back to embedded metadata"); + let sort_key = iox_metadata + .sort_key + .as_ref() + .ok_or_else(|| Error::NoSortKey)?; + + let new_sort_key = sort_key.to_columns().collect::>(); + + // fetch table columns + let columns = get_table_columns_by_id(table.id, repos).await?; + columns.ids_for_names(&new_sort_key) + }; + + loop { + let res = repos + .partitions() + .cas_sort_key(partition.id, partition.sort_key_ids(), &new_sort_key_ids) + .await; + + match res { + Ok(partition) => return Ok(partition), + Err(CasFailure::ValueMismatch(_)) => { + debug!("Value mismatch when setting sort key, retrying..."); + continue; + } + Err(CasFailure::QueryError(e)) => return Err(Error::SetSortKey(e)), + } + } + } + + /// Return a [`ParquetFileParams`] (information needed to
insert + /// the data into the target catalog). + /// + /// First attempts to use any available metadata from the + /// catalog export, and falls back to what is in the iox + /// metadata stored in the parquet file, if needed + #[allow(clippy::too_many_arguments)] + async fn parquet_file_params( + &self, + repos: &mut dyn RepoCollection, + namespace: &Namespace, + table: &Table, + partition: &Partition, + // parquet metadata, if known + parquet_metadata: Option, + iox_metadata: &IoxMetadata, + decoded_iox_parquet_metadata: &DecodedIoxParquetMetaData, + file_size_bytes: usize, + ) -> Result { + let object_store_id = iox_metadata.object_store_id; + + // need to make columns in the target catalog + let column_set = insert_columns(table.id, decoded_iox_parquet_metadata, repos).await?; + + let params = if let Some(proto_parquet_file) = &parquet_metadata { + let compaction_level = proto_parquet_file.compaction_level.try_into()?; + + ParquetFileParams { + namespace_id: namespace.id, + table_id: table.id, + partition_id: partition.id, + partition_hash_id: partition.hash_id().cloned(), + object_store_id, + min_time: Timestamp::new(proto_parquet_file.min_time), + max_time: Timestamp::new(proto_parquet_file.max_time), + file_size_bytes: proto_parquet_file.file_size_bytes, + row_count: proto_parquet_file.row_count, + compaction_level, + created_at: Timestamp::new(proto_parquet_file.created_at), + column_set, + max_l0_created_at: Timestamp::new(proto_parquet_file.max_l0_created_at), + } + } else { + warn!("Could not read parquet file metadata, reconstructing based on encoded metadata"); + + let (min_time, max_time) = get_min_max_times(decoded_iox_parquet_metadata)?; + let created_at = Timestamp::new(iox_metadata.creation_timestamp.timestamp_nanos()); + ParquetFileParams { + namespace_id: namespace.id, + table_id: table.id, + partition_id: partition.id, + partition_hash_id: partition.hash_id().cloned(), + object_store_id, + min_time, + max_time, + // use unwrap: if we can't fit 
the file size or row + // counts into usize, something is very wrong and we + // should stop immediately (and get an exact stack trace) + file_size_bytes: file_size_bytes.try_into().unwrap(), + row_count: decoded_iox_parquet_metadata.row_count().try_into().unwrap(), + //compaction_level: CompactionLevel::Final, + compaction_level: CompactionLevel::Initial, + created_at, + column_set, + max_l0_created_at: created_at, + } + }; + debug!(?params, "Created ParquetFileParams"); + Ok(params) + } +} +/// Returns a `ColumnSet` that represents all the columns specified in +/// `decoded_iox_parquet_metadata`. +/// +/// Insert the appropriate column entries in the catalog they are not +/// already present. +async fn insert_columns( + table_id: TableId, + decoded_iox_parquet_metadata: &DecodedIoxParquetMetaData, + repos: &mut dyn RepoCollection, +) -> Result { + let schema = decoded_iox_parquet_metadata.read_schema()?; + + let mut column_ids = vec![]; + + for (iox_column_type, field) in schema.iter() { + let column_name = field.name(); + let column_type = ColumnType::from(iox_column_type); + + let column = repos + .columns() + .create_or_get(column_name, table_id, column_type) + .await?; + column_ids.push(column.id); + } + + Ok(ColumnSet::new(column_ids)) +} + +/// Reads out the min and max value for the decoded_iox_parquet_metadata column +fn get_min_max_times( + decoded_iox_parquet_metadata: &DecodedIoxParquetMetaData, +) -> Result<(Timestamp, Timestamp)> { + let schema = decoded_iox_parquet_metadata.read_schema()?; + let stats = decoded_iox_parquet_metadata.read_statistics(&schema)?; + + let Some(summary) = stats.iter().find(|s| s.name == schema::TIME_COLUMN_NAME) else { + return Err(Error::BadStats { stats: None }); + }; + + let Statistics::I64(stats) = &summary.stats else { + return Err(Error::BadStats { + stats: Some(summary.stats.clone()), + }); + }; + + let (Some(min), Some(max)) = (stats.min, stats.max) else { + return Err(Error::NoMinMax { + min: stats.min, + max: 
stats.max, + }); + }; + + Ok((Timestamp::new(min), Timestamp::new(max))) + } + + /// Given a filename of the stored parquet metadata, returns the object_store_id + /// + /// For example, `e65790df-3e42-0094-048f-0b69a7ee402c.parquet`, + /// returns `e65790df-3e42-0094-048f-0b69a7ee402c` + /// + /// For some reason the object store id embedded in the parquet file's + /// [`IoxMetadata`] and that of the actual file in object storage are + /// different, so we need to use the object_store_id actually used in + /// the source system, which is embedded in the filename + fn object_store_id_from_parquet_filename(path: &Path) -> Option { + let stem = path + // .partition_id.parquet --> .partition_id + .file_stem()? + .to_string_lossy(); + + Some(stem.to_string()) + } diff --git a/import_export/src/file/mod.rs b/import_export/src/file/mod.rs new file mode 100644 index 0000000..d0c9b1d --- /dev/null +++ b/import_export/src/file/mod.rs @@ -0,0 +1,6 @@ +/// Code to import/export files +mod export; +mod import; + +pub use export::{ExportError, RemoteExporter}; +pub use import::{Error, ExportedContents, RemoteImporter}; diff --git a/import_export/src/lib.rs b/import_export/src/lib.rs new file mode 100644 index 0000000..df7c5ef --- /dev/null +++ b/import_export/src/lib.rs @@ -0,0 +1,17 @@ +//! Import/export utilities for IOx + +#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)] +#![warn( + missing_debug_implementations, + clippy::explicit_iter_loop, + clippy::use_self, + clippy::clone_on_ref_ptr, + clippy::future_not_send, + unused_crate_dependencies +)] + +// Workaround for "unused crate" lint false positives.
+use workspace_hack as _; + +/// Import/Export data to files +pub mod file; diff --git a/influxdb2_client/Cargo.toml b/influxdb2_client/Cargo.toml new file mode 100644 index 0000000..3b05518 --- /dev/null +++ b/influxdb2_client/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "influxdb2_client" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[lints] +workspace = true + +[dependencies] # In alphabetical order +bytes = "1.5" +futures = { version = "0.3", default-features = false } +reqwest = { version = "0.11", default-features = false, features = ["stream", "json", "rustls-tls-native-roots"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0.111" +snafu = "0.8" +url = "2.5.0" +uuid = { version = "1", features = ["v4"] } + +[dev-dependencies] # In alphabetical order +mockito = { version ="1.2", default-features = false } +once_cell = { version = "1.19", features = ["parking_lot"] } +parking_lot = "0.12" +tokio = { version = "1.35", features = ["macros", "parking_lot", "rt-multi-thread", "sync", "time"] } +test_helpers = { path = "../test_helpers" } diff --git a/influxdb2_client/README.md b/influxdb2_client/README.md new file mode 100644 index 0000000..7e94127 --- /dev/null +++ b/influxdb2_client/README.md @@ -0,0 +1,23 @@ +# InfluxDB V2 Client API + +This crate contains a work-in-progress implementation of a Rust client for the [InfluxDB 2.0 API](https://docs.influxdata.com/influxdb/v2.0/reference/api/). + +This client is not the Rust client for IOx. You can find that [here](../influxdb_iox_client). + +The InfluxDB IOx project plans to focus its efforts on the subset of the API which are most relevant to IOx, but we accept (welcome!) PRs for adding the other pieces of functionality. 
+ + ## Design Notes + + When it makes sense, this client aims to mirror the [InfluxDB 2.x Go client API](https://github.com/influxdata/influxdb-client-go) + + ## Contributing + + If you would like to contribute code you can do so through GitHub by forking the repository and sending a pull request into the main branch. + + + ## Future work + + - [ ] Publish as a crate on [crates.io](http://crates.io) + + If you would like to contribute code you can do so through GitHub by forking the repository and sending a pull request into the main branch. diff --git a/influxdb2_client/examples/health.rs b/influxdb2_client/examples/health.rs new file mode 100644 index 0000000..72d0843 --- /dev/null +++ b/influxdb2_client/examples/health.rs @@ -0,0 +1,11 @@ +#[tokio::main] +async fn main() -> Result<(), Box> { + let influx_url = "some-url"; + let token = "some-token"; + + let client = influxdb2_client::Client::new(influx_url, token); + + println!("{:?}", client.health().await?); + + Ok(()) +} diff --git a/influxdb2_client/examples/label.rs b/influxdb2_client/examples/label.rs new file mode 100644 index 0000000..f7a7134 --- /dev/null +++ b/influxdb2_client/examples/label.rs @@ -0,0 +1,29 @@ +use std::collections::HashMap; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let influx_url = "http://localhost:8888"; + let token = "some-token"; + + let client = influxdb2_client::Client::new(influx_url, token); + + println!("{:?}", client.labels().await?); + println!("{:?}", client.labels_by_org("some-org_id").await?); + println!("{:?}", client.find_label("some-label_id").await?); + let mut properties = HashMap::new(); + properties.insert("some-key".to_string(), "some-value".to_string()); + println!( + "{:?}", + client + .create_label("some-org_id", "some-name", Some(properties)) + .await? + ); + println!( + "{:?}", + client + .update_label(Some("some-name".to_string()), None, "some-label_id") + .await?
+ ); + println!("{:?}", client.delete_label("some-label_id").await?); + Ok(()) +} diff --git a/influxdb2_client/examples/query.rs b/influxdb2_client/examples/query.rs new file mode 100644 index 0000000..00bb761 --- /dev/null +++ b/influxdb2_client/examples/query.rs @@ -0,0 +1,26 @@ +use influxdb2_client::models::{LanguageRequest, Query}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let influx_url = "http://localhost:8086"; + let token = "some-token"; + + let client = influxdb2_client::Client::new(influx_url, token); + + client.query_suggestions().await?; + client.query_suggestions_name("some-name").await?; + + client + .query_raw("some-org", Some(Query::new("some-query".to_string()))) + .await?; + + client + .query_analyze(Some(Query::new("some-query".to_string()))) + .await?; + + client + .query_ast(Some(LanguageRequest::new("some-query".to_string()))) + .await?; + + Ok(()) +} diff --git a/influxdb2_client/examples/ready.rs b/influxdb2_client/examples/ready.rs new file mode 100644 index 0000000..07d69c4 --- /dev/null +++ b/influxdb2_client/examples/ready.rs @@ -0,0 +1,11 @@ +#[tokio::main] +async fn main() -> Result<(), Box> { + let influx_url = "some-url"; + let token = "some-token"; + + let client = influxdb2_client::Client::new(influx_url, token); + + println!("{:?}", client.ready().await?); + + Ok(()) +} diff --git a/influxdb2_client/examples/setup.rs b/influxdb2_client/examples/setup.rs new file mode 100644 index 0000000..c54b12f --- /dev/null +++ b/influxdb2_client/examples/setup.rs @@ -0,0 +1,32 @@ +#[tokio::main] +async fn main() -> Result<(), Box> { + let influx_url = "http://localhost:8888"; + let token = "some-token"; + + let client = influxdb2_client::Client::new(influx_url, token); + + if client.is_onboarding_allowed().await? { + println!( + "{:?}", + client + .onboarding("some-user", "some-org", "some-bucket", None, None, None,) + .await? 
+ ); + } + + println!( + "{:?}", + client + .post_setup_user( + "some-new-user", + "some-new-org", + "some-new-bucket", + None, + None, + None, + ) + .await? + ); + + Ok(()) +} diff --git a/influxdb2_client/examples/write.rs b/influxdb2_client/examples/write.rs new file mode 100644 index 0000000..2bbc23d --- /dev/null +++ b/influxdb2_client/examples/write.rs @@ -0,0 +1,27 @@ +use futures::prelude::*; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let org = "myorg"; + let bucket = "mybucket"; + let influx_url = "http://localhost:9999"; + let token = "my-token"; + + let client = influxdb2_client::Client::new(influx_url, token); + + let points = vec![ + influxdb2_client::models::DataPoint::builder("cpu_load_short") + .tag("host", "server01") + .tag("region", "us-west") + .field("value", 0.64) + .build()?, + influxdb2_client::models::DataPoint::builder("cpu_load_short") + .tag("host", "server01") + .field("value", 27.99) + .build()?, + ]; + + client.write(org, bucket, stream::iter(points)).await?; + + Ok(()) +} diff --git a/influxdb2_client/src/api/buckets.rs b/influxdb2_client/src/api/buckets.rs new file mode 100644 index 0000000..92065d0 --- /dev/null +++ b/influxdb2_client/src/api/buckets.rs @@ -0,0 +1,68 @@ +//! Buckets API + +use crate::models::PostBucketRequest; +use crate::{Client, HttpSnafu, RequestError, ReqwestProcessingSnafu, SerializingSnafu}; +use reqwest::Method; +use snafu::ResultExt; + +impl Client { + /// Create a new bucket in the organization specified by the 16-digit + /// hexadecimal `org_id` and with the bucket name `bucket`. 
+ pub async fn create_bucket( + &self, + post_bucket_request: Option, + ) -> Result<(), RequestError> { + let create_bucket_url = format!("{}/api/v2/buckets", self.url); + + let response = self + .request(Method::POST, &create_bucket_url) + .header("Content-Type", "application/json") + .body( + serde_json::to_string(&post_bucket_request.unwrap_or_default()) + .context(SerializingSnafu)?, + ) + .send() + .await + .context(ReqwestProcessingSnafu)?; + + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.context(ReqwestProcessingSnafu)?; + HttpSnafu { status, text }.fail()?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use mockito::Server; + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn create_bucket() { + let org_id = "0000111100001111".to_string(); + let bucket = "some-bucket".to_string(); + let token = "some-token"; + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("POST", "/api/v2/buckets") + .match_header("Authorization", format!("Token {token}").as_str()) + .match_header("Content-Type", "application/json") + .match_body( + format!(r#"{{"orgID":"{org_id}","name":"{bucket}","retentionRules":[]}}"#).as_str(), + ) + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let _result = client + .create_bucket(Some(PostBucketRequest::new(org_id, bucket))) + .await; + + mock.assert_async().await; + } +} diff --git a/influxdb2_client/src/api/health.rs b/influxdb2_client/src/api/health.rs new file mode 100644 index 0000000..4364ec6 --- /dev/null +++ b/influxdb2_client/src/api/health.rs @@ -0,0 +1,53 @@ +//! Health +//! +//! 
Get health of an InfluxDB instance + +use crate::models::HealthCheck; +use crate::{Client, HttpSnafu, RequestError, ReqwestProcessingSnafu}; +use reqwest::{Method, StatusCode}; +use snafu::ResultExt; + +impl Client { + /// Get health of an instance + pub async fn health(&self) -> Result { + let health_url = format!("{}/health", self.url); + let response = self + .request(Method::GET, &health_url) + .send() + .await + .context(ReqwestProcessingSnafu)?; + + match response.status() { + StatusCode::OK => Ok(response + .json::() + .await + .context(ReqwestProcessingSnafu)?), + StatusCode::SERVICE_UNAVAILABLE => Ok(response + .json::() + .await + .context(ReqwestProcessingSnafu)?), + status => { + let text = response.text().await.context(ReqwestProcessingSnafu)?; + HttpSnafu { status, text }.fail()? + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use mockito::Server; + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn health() { + let mut mock_server = Server::new_async().await; + let mock = mock_server.mock("GET", "/health").create_async().await; + + let client = Client::new(mock_server.url(), ""); + + let _result = client.health().await; + + mock.assert_async().await; + } +} diff --git a/influxdb2_client/src/api/label.rs b/influxdb2_client/src/api/label.rs new file mode 100644 index 0000000..f71dc56 --- /dev/null +++ b/influxdb2_client/src/api/label.rs @@ -0,0 +1,318 @@ +//! 
Labels + +use crate::models::{LabelCreateRequest, LabelResponse, LabelUpdate, LabelsResponse}; +use crate::{Client, HttpSnafu, RequestError, ReqwestProcessingSnafu, SerializingSnafu}; +use reqwest::{Method, StatusCode}; +use snafu::ResultExt; +use std::collections::HashMap; + +impl Client { + /// List all Labels + pub async fn labels(&self) -> Result { + self.get_labels(None).await + } + + /// List all Labels by organization ID + pub async fn labels_by_org(&self, org_id: &str) -> Result { + self.get_labels(Some(org_id)).await + } + + async fn get_labels(&self, org_id: Option<&str>) -> Result { + let labels_url = format!("{}/api/v2/labels", self.url); + let mut request = self.request(Method::GET, &labels_url); + + if let Some(id) = org_id { + request = request.query(&[("orgID", id)]); + } + + let response = request.send().await.context(ReqwestProcessingSnafu)?; + match response.status() { + StatusCode::OK => Ok(response + .json::() + .await + .context(ReqwestProcessingSnafu)?), + status => { + let text = response.text().await.context(ReqwestProcessingSnafu)?; + HttpSnafu { status, text }.fail()? + } + } + } + + /// Retrieve a label by ID + pub async fn find_label(&self, label_id: &str) -> Result { + let labels_by_id_url = format!("{}/api/v2/labels/{}", self.url, label_id); + let response = self + .request(Method::GET, &labels_by_id_url) + .send() + .await + .context(ReqwestProcessingSnafu)?; + match response.status() { + StatusCode::OK => Ok(response + .json::() + .await + .context(ReqwestProcessingSnafu)?), + status => { + let text = response.text().await.context(ReqwestProcessingSnafu)?; + HttpSnafu { status, text }.fail()? 
+ } + } + } + + /// Create a Label + pub async fn create_label( + &self, + org_id: &str, + name: &str, + properties: Option>, + ) -> Result { + let create_label_url = format!("{}/api/v2/labels", self.url); + let body = LabelCreateRequest { + org_id: org_id.into(), + name: name.into(), + properties, + }; + let response = self + .request(Method::POST, &create_label_url) + .header("Content-Type", "application/json") + .body(serde_json::to_string(&body).context(SerializingSnafu)?) + .send() + .await + .context(ReqwestProcessingSnafu)?; + match response.status() { + StatusCode::CREATED => Ok(response + .json::() + .await + .context(ReqwestProcessingSnafu)?), + status => { + let text = response.text().await.context(ReqwestProcessingSnafu)?; + HttpSnafu { status, text }.fail()? + } + } + } + + /// Update a Label + pub async fn update_label( + &self, + name: Option, + properties: Option>, + label_id: &str, + ) -> Result { + let update_label_url = format!("{}/api/v2/labels/{}", &self.url, label_id); + let body = LabelUpdate { name, properties }; + let response = self + .request(Method::PATCH, &update_label_url) + .header("Content-Type", "application/json") + .body(serde_json::to_string(&body).context(SerializingSnafu)?) + .send() + .await + .context(ReqwestProcessingSnafu)?; + match response.status() { + StatusCode::OK => Ok(response + .json::() + .await + .context(ReqwestProcessingSnafu)?), + status => { + let text = response.text().await.context(ReqwestProcessingSnafu)?; + HttpSnafu { status, text }.fail()? 
+ } + } + } + + /// Delete a Label + pub async fn delete_label(&self, label_id: &str) -> Result<(), RequestError> { + let delete_label_url = format!("{}/api/v2/labels/{}", &self.url, label_id); + let response = self + .request(Method::DELETE, &delete_label_url) + .send() + .await + .context(ReqwestProcessingSnafu)?; + match response.status() { + StatusCode::NO_CONTENT => Ok(()), + status => { + let text = response.text().await.context(ReqwestProcessingSnafu)?; + HttpSnafu { status, text }.fail()? + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use mockito::Server; + + const BASE_PATH: &str = "/api/v2/labels"; + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn labels() { + let token = "some-token"; + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("GET", BASE_PATH) + .match_header("Authorization", format!("Token {token}").as_str()) + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let _result = client.labels().await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn labels_by_org() { + let token = "some-token"; + let org_id = "some-org_id"; + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("GET", format!("{BASE_PATH}?orgID={org_id}").as_str()) + .match_header("Authorization", format!("Token {token}").as_str()) + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let _result = client.labels_by_org(org_id).await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn find_label() { + let token = "some-token"; + let label_id = "some-id"; + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("GET", format!("{BASE_PATH}/{label_id}").as_str()) + .match_header("Authorization", format!("Token {token}").as_str()) + .create_async() + .await; + + let 
client = Client::new(mock_server.url(), token); + + let _result = client.find_label(label_id).await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn create_label() { + let token = "some-token"; + let org_id = "some-org"; + let name = "some-user"; + let mut properties = HashMap::new(); + properties.insert("some-key".to_string(), "some-value".to_string()); + + let mut mock_server = Server::new_async().await; + let mock = mock_server.mock("POST", BASE_PATH) + .match_header("Authorization", format!("Token {token}").as_str()) + .match_header("Content-Type", "application/json") + .match_body( + format!( + r#"{{"orgID":"{org_id}","name":"{name}","properties":{{"some-key":"some-value"}}}}"# + ) + .as_str(), + ) + .create_async().await; + + let client = Client::new(mock_server.url(), token); + + let _result = client.create_label(org_id, name, Some(properties)).await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn create_label_opt() { + let token = "some-token"; + let org_id = "some-org_id"; + let name = "some-user"; + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("POST", BASE_PATH) + .match_header("Authorization", format!("Token {token}").as_str()) + .match_header("Content-Type", "application/json") + .match_body(format!(r#"{{"orgID":"{org_id}","name":"{name}"}}"#).as_str()) + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let _result = client.create_label(org_id, name, None).await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn update_label() { + let token = "some-token"; + let name = "some-user"; + let label_id = "some-label_id"; + let mut properties = HashMap::new(); + properties.insert("some-key".to_string(), "some-value".to_string()); + + let mut mock_server = Server::new_async().await; + let mock = 
mock_server + .mock("PATCH", format!("{BASE_PATH}/{label_id}").as_str()) + .match_header("Authorization", format!("Token {token}").as_str()) + .match_header("Content-Type", "application/json") + .match_body( + format!(r#"{{"name":"{name}","properties":{{"some-key":"some-value"}}}}"#).as_str(), + ) + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let _result = client + .update_label(Some(name.to_string()), Some(properties), label_id) + .await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn update_label_opt() { + let token = "some-token"; + let label_id = "some-label_id"; + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("PATCH", format!("{BASE_PATH}/{label_id}").as_str()) + .match_header("Authorization", format!("Token {token}").as_str()) + .match_header("Content-Type", "application/json") + .match_body("{}") + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let _result = client.update_label(None, None, label_id).await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn delete_label() { + let token = "some-token"; + let label_id = "some-label_id"; + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("DELETE", format!("{BASE_PATH}/{label_id}").as_str()) + .match_header("Authorization", format!("Token {token}").as_str()) + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let _result = client.delete_label(label_id).await; + + mock.assert_async().await; + } +} diff --git a/influxdb2_client/src/api/mod.rs b/influxdb2_client/src/api/mod.rs new file mode 100644 index 0000000..ece3d9e --- /dev/null +++ b/influxdb2_client/src/api/mod.rs @@ -0,0 +1,8 @@ +//! 
InfluxDB v2.0 Client API +pub mod buckets; +pub mod health; +pub mod label; +pub mod query; +pub mod ready; +pub mod setup; +pub mod write; diff --git a/influxdb2_client/src/api/query.rs b/influxdb2_client/src/api/query.rs new file mode 100644 index 0000000..85e8ee6 --- /dev/null +++ b/influxdb2_client/src/api/query.rs @@ -0,0 +1,384 @@ +//! Query +//! +//! Query InfluxDB using InfluxQL or Flux Query + +use crate::{ + Client, HttpSnafu, RequestError, ReqwestProcessingSnafu, ResponseBytesSnafu, + ResponseStringSnafu, SerializingSnafu, +}; +use reqwest::{Method, StatusCode}; +use snafu::ResultExt; + +use crate::models::{ + AnalyzeQueryResponse, AstResponse, FluxSuggestion, FluxSuggestions, LanguageRequest, Query, +}; + +impl Client { + /// Get Query Suggestions + pub async fn query_suggestions(&self) -> Result { + let req_url = format!("{}/api/v2/query/suggestions", self.url); + let response = self + .request(Method::GET, &req_url) + .send() + .await + .context(ReqwestProcessingSnafu)?; + + match response.status() { + StatusCode::OK => Ok(response + .json::() + .await + .context(ReqwestProcessingSnafu)?), + status => { + let text = response.text().await.context(ReqwestProcessingSnafu)?; + HttpSnafu { status, text }.fail()? + } + } + } + + /// Query Suggestions with name + pub async fn query_suggestions_name(&self, name: &str) -> Result { + let req_url = format!( + "{}/api/v2/query/suggestions/{name}", + self.url, + name = crate::common::urlencode(name), + ); + + let response = self + .request(Method::GET, &req_url) + .send() + .await + .context(ReqwestProcessingSnafu)?; + + match response.status() { + StatusCode::OK => Ok(response + .json::() + .await + .context(ReqwestProcessingSnafu)?), + status => { + let text = response.text().await.context(ReqwestProcessingSnafu)?; + HttpSnafu { status, text }.fail()? 
+ } + } + } + + /// Query and return the raw string data from the server + pub async fn query_raw(&self, org: &str, query: Option) -> Result { + let req_url = format!("{}/api/v2/query", self.url); + + let response = self + .request(Method::POST, &req_url) + .header("Accepting-Encoding", "identity") + .header("Content-Type", "application/json") + .query(&[("org", &org)]) + .body(serde_json::to_string(&query.unwrap_or_default()).context(SerializingSnafu)?) + .send() + .await + .context(ReqwestProcessingSnafu)?; + + match response.status() { + StatusCode::OK => { + let bytes = response.bytes().await.context(ResponseBytesSnafu)?; + String::from_utf8(bytes.to_vec()).context(ResponseStringSnafu) + } + status => { + let text = response.text().await.context(ReqwestProcessingSnafu)?; + HttpSnafu { status, text }.fail()? + } + } + } + + /// Analyze Query + pub async fn query_analyze( + &self, + query: Option, + ) -> Result { + let req_url = format!("{}/api/v2/query/analyze", self.url); + + let response = self + .request(Method::POST, &req_url) + .header("Content-Type", "application/json") + .body(serde_json::to_string(&query.unwrap_or_default()).context(SerializingSnafu)?) + .send() + .await + .context(ReqwestProcessingSnafu)?; + + match response.status() { + StatusCode::OK => Ok(response + .json::() + .await + .context(ReqwestProcessingSnafu)?), + status => { + let text = response.text().await.context(ReqwestProcessingSnafu)?; + HttpSnafu { status, text }.fail()? 
+ } + } + } + + /// Get Query AST Repsonse + pub async fn query_ast( + &self, + language_request: Option, + ) -> Result { + let req_url = format!("{}/api/v2/query/ast", self.url); + + let response = self + .request(Method::POST, &req_url) + .header("Content-Type", "application/json") + .body( + serde_json::to_string(&language_request.unwrap_or_default()) + .context(SerializingSnafu)?, + ) + .send() + .await + .context(ReqwestProcessingSnafu)?; + + match response.status() { + StatusCode::OK => Ok(response + .json::() + .await + .context(ReqwestProcessingSnafu)?), + status => { + let text = response.text().await.context(ReqwestProcessingSnafu)?; + HttpSnafu { status, text }.fail()? + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use mockito::{Matcher, Server}; + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn query_suggestions() { + let token = "some-token"; + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("GET", "/api/v2/query/suggestions") + .match_header("Authorization", format!("Token {token}").as_str()) + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let _result = client.query_suggestions().await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn query_suggestions_name() { + let token = "some-token"; + let suggestion_name = "some-name"; + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock( + "GET", + format!( + "/api/v2/query/suggestions/{name}", + name = crate::common::urlencode(suggestion_name) + ) + .as_str(), + ) + .match_header("Authorization", format!("Token {token}").as_str()) + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let _result = client.query_suggestions_name(suggestion_name).await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn query_raw() 
{ + let token = "some-token"; + let org = "some-org"; + let query: Option = Some(Query::new("some-influx-query-string".to_string())); + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("POST", "/api/v2/query") + .match_header("Authorization", format!("Token {token}").as_str()) + .match_header("Accepting-Encoding", "identity") + .match_header("Content-Type", "application/json") + .match_query(Matcher::UrlEncoded("org".into(), org.into())) + .match_body( + serde_json::to_string(&query.clone().unwrap_or_default()) + .unwrap() + .as_str(), + ) + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let _result = client.query_raw(org, query).await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn query_raw_opt() { + let token = "some-token"; + let org = "some-org"; + let query: Option = None; + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("POST", "/api/v2/query") + .match_header("Authorization", format!("Token {token}").as_str()) + .match_header("Accepting-Encoding", "identity") + .match_header("Content-Type", "application/json") + .match_query(Matcher::UrlEncoded("org".into(), org.into())) + .match_body( + #[allow(clippy::unnecessary_literal_unwrap)] + serde_json::to_string(&query.unwrap_or_default()) + .unwrap() + .as_str(), + ) + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let _result = client.query_raw(org, None).await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn query_analyze() { + let token = "some-token"; + let query: Option = Some(Query::new("some-influx-query-string".to_string())); + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("POST", "/api/v2/query/analyze") + .match_header("Authorization", format!("Token {token}").as_str()) + 
.match_header("Content-Type", "application/json") + .match_body( + serde_json::to_string(&query.clone().unwrap_or_default()) + .unwrap() + .as_str(), + ) + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let _result = client.query_analyze(query).await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn query_analyze_opt() { + let token = "some-token"; + let query: Option = None; + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("POST", "/api/v2/query/analyze") + .match_header("Authorization", format!("Token {token}").as_str()) + .match_header("Content-Type", "application/json") + .match_body( + serde_json::to_string(&query.clone().unwrap_or_default()) + .unwrap() + .as_str(), + ) + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let _result = client.query_analyze(query).await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn query_ast() { + let token = "some-token"; + let language_request: Option = + Some(LanguageRequest::new("some-influx-query-string".to_string())); + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("POST", "/api/v2/query/ast") + .match_header("Authorization", format!("Token {token}").as_str()) + .match_header("Content-Type", "application/json") + .match_body( + serde_json::to_string(&language_request.clone().unwrap_or_default()) + .unwrap() + .as_str(), + ) + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let _result = client.query_ast(language_request).await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn query_ast_opt() { + let token = "some-token"; + let language_request: Option = None; + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("POST", 
"/api/v2/query/ast") + .match_header("Authorization", format!("Token {token}").as_str()) + .match_header("Content-Type", "application/json") + .match_body( + serde_json::to_string(&language_request.clone().unwrap_or_default()) + .unwrap() + .as_str(), + ) + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let _result = client.query_ast(language_request).await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn query_raw_no_results() { + let token = "some-token"; + let org = "some-org"; + let query: Option = Some(Query::new("some-influx-query-string".to_string())); + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("POST", "/api/v2/query") + .match_header("Authorization", format!("Token {token}").as_str()) + .match_header("Accepting-Encoding", "identity") + .match_header("Content-Type", "application/json") + .match_query(Matcher::UrlEncoded("org".into(), org.into())) + .match_body( + serde_json::to_string(&query.clone().unwrap_or_default()) + .unwrap() + .as_str(), + ) + .with_body("") + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let result = client.query_raw(org, query).await.expect("request success"); + assert_eq!(result, ""); + + mock.assert_async().await; + } +} diff --git a/influxdb2_client/src/api/ready.rs b/influxdb2_client/src/api/ready.rs new file mode 100644 index 0000000..6765316 --- /dev/null +++ b/influxdb2_client/src/api/ready.rs @@ -0,0 +1,47 @@ +//! Ready +//! +//! 
Check readiness of an InfluxDB instance at startup + +use reqwest::{Method, StatusCode}; +use snafu::ResultExt; + +use crate::{Client, HttpSnafu, RequestError, ReqwestProcessingSnafu}; + +impl Client { + /// Get the readiness of an instance at startup + pub async fn ready(&self) -> Result { + let ready_url = format!("{}/ready", self.url); + let response = self + .request(Method::GET, &ready_url) + .send() + .await + .context(ReqwestProcessingSnafu)?; + + match response.status() { + StatusCode::OK => Ok(true), + _ => { + let status = response.status(); + let text = response.text().await.context(ReqwestProcessingSnafu)?; + HttpSnafu { status, text }.fail()? + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use mockito::Server; + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn ready() { + let mut mock_server = Server::new_async().await; + let mock = mock_server.mock("GET", "/ready").create_async().await; + + let client = Client::new(mock_server.url(), ""); + + let _result = client.ready().await; + + mock.assert_async().await; + } +} diff --git a/influxdb2_client/src/api/setup.rs b/influxdb2_client/src/api/setup.rs new file mode 100644 index 0000000..590283b --- /dev/null +++ b/influxdb2_client/src/api/setup.rs @@ -0,0 +1,261 @@ +//! Onboarding/Setup +//! +//! Initate and start onboarding process of InfluxDB server. 
+ +use crate::{Client, HttpSnafu, RequestError, ReqwestProcessingSnafu, SerializingSnafu}; +use reqwest::{Method, StatusCode}; +use snafu::ResultExt; + +use crate::models::{IsOnboarding, OnboardingRequest, OnboardingResponse}; + +impl Client { + /// Check if database has default user, org, bucket + pub async fn is_onboarding_allowed(&self) -> Result { + let setup_url = format!("{}/api/v2/setup", self.url); + let response = self + .request(Method::GET, &setup_url) + .send() + .await + .context(ReqwestProcessingSnafu)?; + + match response.status() { + StatusCode::OK => Ok(response + .json::() + .await + .context(ReqwestProcessingSnafu)? + .allowed), + status => { + let text = response.text().await.context(ReqwestProcessingSnafu)?; + HttpSnafu { status, text }.fail()? + } + } + } + + /// Set up initial user, org and bucket + pub async fn onboarding( + &self, + username: &str, + org: &str, + bucket: &str, + password: Option, + retention_period_hrs: Option, + retention_period_seconds: Option, + ) -> Result { + let setup_init_url = format!("{}/api/v2/setup", self.url); + + let body = OnboardingRequest { + username: username.into(), + org: org.into(), + bucket: bucket.into(), + password, + retention_period_hrs, + retention_period_seconds, + }; + + let response = self + .request(Method::POST, &setup_init_url) + .header("Content-Type", "application/json") + .body(serde_json::to_string(&body).context(SerializingSnafu)?) + .send() + .await + .context(ReqwestProcessingSnafu)?; + + match response.status() { + StatusCode::CREATED => Ok(response + .json::() + .await + .context(ReqwestProcessingSnafu)?), + status => { + let text = response.text().await.context(ReqwestProcessingSnafu)?; + HttpSnafu { status, text }.fail()? 
+ } + } + } + + /// Set up a new user, org and bucket + pub async fn post_setup_user( + &self, + username: &str, + org: &str, + bucket: &str, + password: Option, + retention_period_hrs: Option, + retention_period_seconds: Option, + ) -> Result { + let setup_new_url = format!("{}/api/v2/setup/user", self.url); + + let body = OnboardingRequest { + username: username.into(), + org: org.into(), + bucket: bucket.into(), + password, + retention_period_hrs, + retention_period_seconds, + }; + + let response = self + .request(Method::POST, &setup_new_url) + .header("Content-Type", "application/json") + .body(serde_json::to_string(&body).context(SerializingSnafu)?) + .send() + .await + .context(ReqwestProcessingSnafu)?; + + match response.status() { + StatusCode::CREATED => Ok(response + .json::() + .await + .context(ReqwestProcessingSnafu)?), + status => { + let text = response.text().await.context(ReqwestProcessingSnafu)?; + HttpSnafu { status, text }.fail()? + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use mockito::Server; + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn is_onboarding_allowed() { + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("GET", "/api/v2/setup") + .create_async() + .await; + + let client = Client::new(mock_server.url(), ""); + + let _result = client.is_onboarding_allowed().await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn onboarding() { + let token = "some-token"; + let username = "some-user"; + let org = "some-org"; + let bucket = "some-bucket"; + let password = "some-password"; + let retention_period_hrs = 1; + + let mut mock_server = Server::new_async().await; + let mock = mock_server.mock("POST", "/api/v2/setup") + .match_header("Content-Type", "application/json") + .match_body( + format!( + 
r#"{{"username":"{username}","org":"{org}","bucket":"{bucket}","password":"{password}","retentionPeriodHrs":{retention_period_hrs}}}"# + ).as_str(), + ) + .create_async().await; + + let client = Client::new(mock_server.url(), token); + + let _result = client + .onboarding( + username, + org, + bucket, + Some(password.to_string()), + Some(retention_period_hrs), + None, + ) + .await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn post_setup_user() { + let token = "some-token"; + let username = "some-user"; + let org = "some-org"; + let bucket = "some-bucket"; + let password = "some-password"; + let retention_period_hrs = 1; + + let mut mock_server = Server::new_async().await; + let mock = mock_server.mock("POST", "/api/v2/setup/user") + .match_header("Authorization", format!("Token {token}").as_str()) + .match_header("Content-Type", "application/json") + .match_body( + format!( + r#"{{"username":"{username}","org":"{org}","bucket":"{bucket}","password":"{password}","retentionPeriodHrs":{retention_period_hrs}}}"# + ).as_str(), + ) + .create_async().await; + + let client = Client::new(mock_server.url(), token); + + let _result = client + .post_setup_user( + username, + org, + bucket, + Some(password.to_string()), + Some(retention_period_hrs), + None, + ) + .await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn onboarding_opt() { + let username = "some-user"; + let org = "some-org"; + let bucket = "some-bucket"; + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("POST", "/api/v2/setup") + .match_header("Content-Type", "application/json") + .match_body( + format!(r#"{{"username":"{username}","org":"{org}","bucket":"{bucket}"}}"#,) + .as_str(), + ) + .create_async() + .await; + + let client = Client::new(mock_server.url(), ""); + + let _result = client + .onboarding(username, org, bucket, None, None, None) + 
.await; + + mock.assert_async().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn post_setup_user_opt() { + let token = "some-token"; + let username = "some-user"; + let org = "some-org"; + let bucket = "some-bucket"; + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock("POST", "/api/v2/setup/user") + .match_header("Authorization", format!("Token {token}").as_str()) + .match_header("Content-Type", "application/json") + .match_body( + format!(r#"{{"username":"{username}","org":"{org}","bucket":"{bucket}"}}"#,) + .as_str(), + ) + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let _result = client + .post_setup_user(username, org, bucket, None, None, None) + .await; + + mock.assert_async().await; + } +} diff --git a/influxdb2_client/src/api/write.rs b/influxdb2_client/src/api/write.rs new file mode 100644 index 0000000..f97b555 --- /dev/null +++ b/influxdb2_client/src/api/write.rs @@ -0,0 +1,116 @@ +//! Write API + +use crate::models::WriteDataPoint; +use crate::{Client, HttpSnafu, RequestError, ReqwestProcessingSnafu}; +use bytes::BufMut; +use futures::{Stream, StreamExt}; +use reqwest::{Body, Method}; +use snafu::ResultExt; +use std::io::{self, Write}; + +impl Client { + /// Write line protocol data to the specified organization and bucket. 
+ pub async fn write_line_protocol( + &self, + org: &str, + bucket: &str, + body: impl Into + Send, + ) -> Result<(), RequestError> { + let body = body.into(); + let write_url = format!("{}/api/v2/write", self.url); + + let response = self + .request(Method::POST, &write_url) + .query(&[("bucket", bucket), ("org", org)]) + .body(body) + .send() + .await + .context(ReqwestProcessingSnafu)?; + + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.context(ReqwestProcessingSnafu)?; + HttpSnafu { status, text }.fail()?; + } + + Ok(()) + } + + /// Write a `Stream` of `DataPoint`s to the specified organization and + /// bucket. + pub async fn write( + &self, + org: &str, + bucket: &str, + body: impl Stream + Send + Sync + 'static, + ) -> Result<(), RequestError> { + let mut buffer = bytes::BytesMut::new(); + + let body = body.map(move |point| { + let mut w = (&mut buffer).writer(); + point.write_data_point_to(&mut w)?; + w.flush()?; + Ok::<_, io::Error>(buffer.split().freeze()) + }); + + let body = Body::wrap_stream(body); + + self.write_line_protocol(org, bucket, body).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::models::DataPoint; + use futures::stream; + use mockito::Server; + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn writing_points() { + let org = "some-org"; + let bucket = "some-bucket"; + let token = "some-token"; + + let mut mock_server = Server::new_async().await; + let mock = mock_server + .mock( + "POST", + format!("/api/v2/write?bucket={bucket}&org={org}").as_str(), + ) + .match_header("Authorization", format!("Token {token}").as_str()) + .match_body( + "\ +cpu,host=server01 usage=0.5 +cpu,host=server01,region=us-west usage=0.87 +", + ) + .create_async() + .await; + + let client = Client::new(mock_server.url(), token); + + let points = vec![ + DataPoint::builder("cpu") + .tag("host", "server01") + .field("usage", 0.5) + .build() + .unwrap(), + 
DataPoint::builder("cpu") + .tag("host", "server01") + .tag("region", "us-west") + .field("usage", 0.87) + .build() + .unwrap(), + ]; + + // If the requests made are incorrect, Mockito returns status 501 and `write` + // will return an error, which causes the test to fail here instead of + // when we assert on mock_server. The error messages that Mockito + // provides are much clearer for explaining why a test failed than just + // that the server returned 501, so don't use `?` here. + let _result = client.write(org, bucket, stream::iter(points)).await; + + mock.assert_async().await; + } +} diff --git a/influxdb2_client/src/common.rs b/influxdb2_client/src/common.rs new file mode 100644 index 0000000..d51ea20 --- /dev/null +++ b/influxdb2_client/src/common.rs @@ -0,0 +1,8 @@ +//! Common +//! +//! Collection of helper functions + +/// Serialize to application/x-www-form-urlencoded syntax +pub fn urlencode>(s: T) -> String { + ::url::form_urlencoded::byte_serialize(s.as_ref().as_bytes()).collect() +} diff --git a/influxdb2_client/src/lib.rs b/influxdb2_client/src/lib.rs new file mode 100644 index 0000000..0db577e --- /dev/null +++ b/influxdb2_client/src/lib.rs @@ -0,0 +1,203 @@ +#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)] +// `clippy::use_self` is deliberately excluded from the lints this crate uses. +// See . +#![warn( + missing_copy_implementations, + missing_debug_implementations, + missing_docs, + clippy::explicit_iter_loop, + clippy::clone_on_ref_ptr, + // See https://github.com/influxdata/influxdb_iox/pull/1671 + clippy::future_not_send, + clippy::todo, + clippy::dbg_macro, + unused_crate_dependencies +)] + +//! # influxdb2_client +//! +//! This is a Rust client to InfluxDB using the [2.0 API][2api]. +//! +//! [2api]: https://v2.docs.influxdata.com/v2.0/reference/api/ +//! +//! ## Work Remaining +//! +//! - Query +//! - optional sync client +//! - Influx 1.x API? +//! - Other parts of the API +//! 
- Pick the best name to use on crates.io and publish +//! +//! ## Quick start +//! +//! This example creates a client to an InfluxDB server running at `http://localhost:8888`, creates +//! a bucket with the name "mybucket" in the organization with name "myorg" and +//! ID "0000111100001111", builds two points, and writes the points to the +//! bucket. +//! +//! ``` +//! async fn example() -> Result<(), Box> { +//! use influxdb2_client::Client; +//! use influxdb2_client::models::{DataPoint, PostBucketRequest}; +//! use futures::stream; +//! +//! let org = "myorg"; +//! let org_id = "0000111100001111"; +//! let bucket = "mybucket"; +//! +//! let client = Client::new("http://localhost:8888", "some-token"); +//! +//! client.create_bucket( +//! Some(PostBucketRequest::new(org_id.to_string(), bucket.to_string())) +//! ).await?; +//! +//! let points = vec![ +//! DataPoint::builder("cpu") +//! .tag("host", "server01") +//! .field("usage", 0.5) +//! .build()?, +//! DataPoint::builder("cpu") +//! .tag("host", "server01") +//! .tag("region", "us-west") +//! .field("usage", 0.87) +//! .build()?, +//! ]; +//! +//! client.write(org, bucket, stream::iter(points)).await?; +//! Ok(()) +//! } +//! ``` + +// Workaround for "unused crate" lint false positives. +#[cfg(test)] +use once_cell as _; +#[cfg(test)] +use parking_lot as _; +#[cfg(test)] +use test_helpers as _; + +use reqwest::Method; +use snafu::Snafu; + +/// Errors that occur while making requests to the Influx server. +#[derive(Debug, Snafu)] +pub enum RequestError { + /// While making a request to the Influx server, the underlying `reqwest` + /// library returned an error that was not an HTTP 400 or 500. + #[snafu(display("Error while processing the HTTP request: {}", source))] + ReqwestProcessing { + /// The underlying error object from `reqwest`. + source: reqwest::Error, + }, + /// The underlying `reqwest` library returned an HTTP error with code 400 + /// (meaning a client error) or 500 (meaning a server error). 
+ #[snafu(display("HTTP request returned an error: {}, `{}`", status, text))] + Http { + /// The `StatusCode` returned from the request + status: reqwest::StatusCode, + /// Any text data returned from the request + text: String, + }, + + /// While serializing data as JSON to send in a request, the underlying + /// `serde_json` library returned an error. + #[snafu(display("Error while serializing to JSON: {}", source))] + Serializing { + /// The underlying error object from `serde_json`. + source: serde_json::error::Error, + }, + + /// While deserializing the response as JSON, something went wrong. + #[snafu(display("Could not deserialize as JSON. Error: {source}\nText: `{text}`"))] + DeserializingJsonResponse { + /// The text of the response + text: String, + /// The underlying error object from serde + source: serde_json::Error, + }, + + /// Something went wrong getting the raw bytes of the response + #[snafu(display("Could not get response bytes: {source}"))] + ResponseBytes { + /// The underlying error object from reqwest + source: reqwest::Error, + }, + + /// Something went wrong converting the raw bytes of the response to a UTF-8 string + #[snafu(display("Invalid UTF-8: {source}"))] + ResponseString { + /// The underlying error object from std + source: std::string::FromUtf8Error, + }, +} + +/// Client to a server supporting the InfluxData 2.0 API. +#[derive(Debug, Clone)] +pub struct Client { + /// The base URL this client sends requests to + pub url: String, + auth_header: Option, + reqwest: reqwest::Client, + jaeger_debug_header: Option, +} + +impl Client { + /// Default [jaeger debug header](Self::with_jaeger_debug) that should work in many + /// environments. + pub const DEFAULT_JAEGER_DEBUG_HEADER: &'static str = "jaeger-debug-id"; + + /// Create a new client pointing to the URL specified in + /// `protocol://server:port` format and using the specified token for + /// authorization. 
+ /// + /// # Example + /// + /// ``` + /// let client = influxdb2_client::Client::new("http://localhost:8888", "my-token"); + /// ``` + pub fn new(url: impl Into, auth_token: impl Into) -> Self { + let token = auth_token.into(); + let auth_header = if token.is_empty() { + None + } else { + Some(format!("Token {token}")) + }; + + Self { + url: url.into(), + auth_header, + reqwest: reqwest::Client::builder() + .connection_verbose(true) + .build() + .expect("reqwest::Client should have built"), + jaeger_debug_header: None, + } + } + + /// Enable generation of jaeger debug headers with the given header name. + pub fn with_jaeger_debug(self, header: String) -> Self { + Self { + jaeger_debug_header: Some(header), + ..self + } + } + + /// Consolidate common request building code + fn request(&self, method: Method, url: &str) -> reqwest::RequestBuilder { + let mut req = self.reqwest.request(method, url); + + if let Some(auth) = &self.auth_header { + req = req.header("Authorization", auth); + } + if let Some(header) = &self.jaeger_debug_header { + req = req.header(header, format!("influxdb_client-{}", uuid::Uuid::new_v4())); + } + + req + } +} + +pub mod common; + +pub mod api; +pub mod models; diff --git a/influxdb2_client/src/models/ast/call_expression.rs b/influxdb2_client/src/models/ast/call_expression.rs new file mode 100644 index 0000000..23dcc1c --- /dev/null +++ b/influxdb2_client/src/models/ast/call_expression.rs @@ -0,0 +1,24 @@ +//! 
CallExpression + +use serde::{Deserialize, Serialize}; + +/// Represents a function call +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct CallExpression { + /// Type of AST node + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// Callee + #[serde(skip_serializing_if = "Option::is_none")] + pub callee: Option>, + /// Function arguments + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub arguments: Vec, +} + +impl CallExpression { + /// Represents a function call + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/ast/dialect.rs b/influxdb2_client/src/models/ast/dialect.rs new file mode 100644 index 0000000..d4e5c78 --- /dev/null +++ b/influxdb2_client/src/models/ast/dialect.rs @@ -0,0 +1,54 @@ +//! Dialect + +use serde::{Deserialize, Serialize}; + +/// Dialect are options to change the default CSV output format; +/// +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Dialect { + /// If true, the results will contain a header row + #[serde(skip_serializing_if = "Option::is_none")] + pub header: Option, + /// Separator between cells; the default is , + #[serde(skip_serializing_if = "Option::is_none")] + pub delimiter: Option, + /// + #[serde(skip_serializing_if = "Option::is_none")] + pub annotations: Option>, + /// Character prefixed to comment strings + #[serde(skip_serializing_if = "Option::is_none")] + pub comment_prefix: Option, + /// Format of timestamps + #[serde(skip_serializing_if = "Option::is_none")] + pub date_time_format: Option, +} + +impl Dialect { + /// Dialect are options to change the default CSV output format; + /// + pub fn new() -> Self { + Self::default() + } +} + +/// +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum Annotations { + /// Group Annotation + 
Group, + /// Datatype Annotation + Datatype, + /// Default Annotation + Default, +} + +/// Timestamp Format +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +pub enum DateTimeFormat { + /// RFC3339 + Rfc3339, + /// RFC3339Nano + Rfc3339Nano, +} diff --git a/influxdb2_client/src/models/ast/dict_item.rs b/influxdb2_client/src/models/ast/dict_item.rs new file mode 100644 index 0000000..1f1fa40 --- /dev/null +++ b/influxdb2_client/src/models/ast/dict_item.rs @@ -0,0 +1,24 @@ +//! DictItem + +use serde::{Deserialize, Serialize}; + +/// A key/value pair in a dictionary +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct DictItem { + /// Type of AST node + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// Key + #[serde(skip_serializing_if = "Option::is_none")] + pub key: Option, + /// Value + #[serde(skip_serializing_if = "Option::is_none")] + pub val: Option, +} + +impl DictItem { + /// A key/value pair in a dictionary + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/ast/duration.rs b/influxdb2_client/src/models/ast/duration.rs new file mode 100644 index 0000000..fae6a61 --- /dev/null +++ b/influxdb2_client/src/models/ast/duration.rs @@ -0,0 +1,27 @@ +//! Duration + +use serde::{Deserialize, Serialize}; + +/// Duration : A pair consisting of length of time and the unit of time +/// measured. It is the atomic unit from which all duration literals are +/// composed. 
+#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct Duration { + /// Type of AST node + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// Duration Magnitude + #[serde(skip_serializing_if = "Option::is_none")] + pub magnitude: Option, + /// Duration unit + #[serde(skip_serializing_if = "Option::is_none")] + pub unit: Option, +} + +impl Duration { + /// A pair consisting of length of time and the unit of time measured. It is + /// the atomic unit from which all duration literals are composed. + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/ast/expression.rs b/influxdb2_client/src/models/ast/expression.rs new file mode 100644 index 0000000..5d4b48b --- /dev/null +++ b/influxdb2_client/src/models/ast/expression.rs @@ -0,0 +1,84 @@ +//! Expression + +use serde::{Deserialize, Serialize}; + +/// Expression AST +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct Expression { + /// Type of AST node + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// Elements of the dictionary + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub elements: Vec, + /// Function parameters + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub params: Vec, + /// Node + #[serde(skip_serializing_if = "Option::is_none")] + pub body: Option, + /// Operator + #[serde(skip_serializing_if = "Option::is_none")] + pub operator: Option, + /// Left leaf + #[serde(skip_serializing_if = "Option::is_none")] + pub left: Option>, + /// Right leaf + #[serde(skip_serializing_if = "Option::is_none")] + pub right: Option>, + /// Parent Expression + #[serde(skip_serializing_if = "Option::is_none")] + pub callee: Option>, + /// Function arguments + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub arguments: Vec, + /// Test Expr + #[serde(skip_serializing_if = "Option::is_none")] + pub 
test: Option>, + /// Alternate Expr + #[serde(skip_serializing_if = "Option::is_none")] + pub alternate: Option>, + /// Consequent Expr + #[serde(skip_serializing_if = "Option::is_none")] + pub consequent: Option>, + /// Object Expr + #[serde(skip_serializing_if = "Option::is_none")] + pub object: Option>, + /// PropertyKey Expr + #[serde(skip_serializing_if = "Option::is_none")] + pub property: Option>, + /// Array Expr + #[serde(skip_serializing_if = "Option::is_none")] + pub array: Option>, + /// Index Expr + #[serde(skip_serializing_if = "Option::is_none")] + pub index: Option>, + /// Properties + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub properties: Vec, + /// Expression + #[serde(skip_serializing_if = "Option::is_none")] + pub expression: Option>, + /// Argument + #[serde(skip_serializing_if = "Option::is_none")] + pub argument: Option>, + /// Call Expr + #[serde(skip_serializing_if = "Option::is_none")] + pub call: Option, + /// Expression Value + #[serde(skip_serializing_if = "Option::is_none")] + pub value: Option, + /// Duration values + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub values: Vec, + /// Expression Name + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +impl Expression { + /// Return instance of expression + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/ast/identifier.rs b/influxdb2_client/src/models/ast/identifier.rs new file mode 100644 index 0000000..361c223 --- /dev/null +++ b/influxdb2_client/src/models/ast/identifier.rs @@ -0,0 +1,21 @@ +//! 
Idendifier + +use serde::{Deserialize, Serialize}; + +/// A valid Flux identifier +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct Identifier { + /// Type of AST node + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// Identifier Name + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +impl Identifier { + /// A valid Flux identifier + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/ast/import_declaration.rs b/influxdb2_client/src/models/ast/import_declaration.rs new file mode 100644 index 0000000..5caec07 --- /dev/null +++ b/influxdb2_client/src/models/ast/import_declaration.rs @@ -0,0 +1,24 @@ +//! ImportDeclaration + +use serde::{Deserialize, Serialize}; + +/// Declares a package import +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct ImportDeclaration { + /// Type of AST node + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// Import Identifier + #[serde(rename = "as", skip_serializing_if = "Option::is_none")] + pub r#as: Option, + /// Import Path + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, +} + +impl ImportDeclaration { + /// Declares a package import + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/ast/member_expression.rs b/influxdb2_client/src/models/ast/member_expression.rs new file mode 100644 index 0000000..d3e4dba --- /dev/null +++ b/influxdb2_client/src/models/ast/member_expression.rs @@ -0,0 +1,24 @@ +//! 
MemberExpression + +use serde::{Deserialize, Serialize}; + +/// Represents accessing a property of an object +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct MemberExpression { + /// Type of AST node + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// Member object + #[serde(skip_serializing_if = "Option::is_none")] + pub object: Option, + /// Member Property + #[serde(skip_serializing_if = "Option::is_none")] + pub property: Option, +} + +impl MemberExpression { + /// Represents accessing a property of an object + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/ast/mod.rs b/influxdb2_client/src/models/ast/mod.rs new file mode 100644 index 0000000..f170a85 --- /dev/null +++ b/influxdb2_client/src/models/ast/mod.rs @@ -0,0 +1,34 @@ +//! Query AST models + +pub mod identifier; +pub use self::identifier::Identifier; +pub mod statement; +pub use self::statement::Statement; +pub mod expression; +pub use self::expression::Expression; +pub mod call_expression; +pub use self::call_expression::CallExpression; +pub mod member_expression; +pub use self::member_expression::MemberExpression; +pub mod string_literal; +pub use self::string_literal::StringLiteral; +pub mod dict_item; +pub use self::dict_item::DictItem; +pub mod variable_assignment; +pub use self::variable_assignment::VariableAssignment; +pub mod node; +pub use self::node::Node; +pub mod property; +pub use self::property::Property; +pub mod property_key; +pub use self::property_key::PropertyKey; +pub mod dialect; +pub use self::dialect::Dialect; +pub mod import_declaration; +pub use self::import_declaration::ImportDeclaration; +pub mod package; +pub use self::package::Package; +pub mod package_clause; +pub use self::package_clause::PackageClause; +pub mod duration; +pub use self::duration::Duration; diff --git a/influxdb2_client/src/models/ast/node.rs b/influxdb2_client/src/models/ast/node.rs 
new file mode 100644 index 0000000..e6bfcf5 --- /dev/null +++ b/influxdb2_client/src/models/ast/node.rs @@ -0,0 +1,84 @@ +//! Node + +use serde::{Deserialize, Serialize}; + +/// Node +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct Node { + /// Type of AST node + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// Elements of the dictionary + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub elements: Vec, + /// Function parameters + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub params: Vec, + /// Block body + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub body: Vec, + /// Node Operator + #[serde(skip_serializing_if = "Option::is_none")] + pub operator: Option, + /// Left left node + #[serde(skip_serializing_if = "Option::is_none")] + pub left: Option>, + /// Right right node + #[serde(skip_serializing_if = "Option::is_none")] + pub right: Option>, + /// Parent node + #[serde(skip_serializing_if = "Option::is_none")] + pub callee: Option>, + /// Function arguments + #[serde(skip_serializing_if = "Vec::is_empty")] + pub arguments: Vec, + /// Test Expr + #[serde(skip_serializing_if = "Option::is_none")] + pub test: Option>, + /// Alternate Expr + #[serde(skip_serializing_if = "Option::is_none")] + pub alternate: Option>, + /// Consequent Expr + #[serde(skip_serializing_if = "Option::is_none")] + pub consequent: Option>, + /// Object Expr + #[serde(skip_serializing_if = "Option::is_none")] + pub object: Option>, + /// PropertyKey + #[serde(skip_serializing_if = "Option::is_none")] + pub property: Option, + /// Array Expr + #[serde(skip_serializing_if = "Option::is_none")] + pub array: Option>, + /// Index Expr + #[serde(skip_serializing_if = "Option::is_none")] + pub index: Option>, + /// Object properties + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub properties: Vec, + /// Expression + #[serde(skip_serializing_if = 
"Option::is_none")] + pub expression: Option>, + /// Node arguments + #[serde(skip_serializing_if = "Option::is_none")] + pub argument: Option>, + /// Call Expr + #[serde(skip_serializing_if = "Option::is_none")] + pub call: Option, + /// Node Value + #[serde(skip_serializing_if = "Option::is_none")] + pub value: Option, + /// Duration values + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub values: Vec, + /// Node name + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +impl Node { + /// Return instance of Node + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/ast/package.rs b/influxdb2_client/src/models/ast/package.rs new file mode 100644 index 0000000..84f97a0 --- /dev/null +++ b/influxdb2_client/src/models/ast/package.rs @@ -0,0 +1,28 @@ +//! Package + +use crate::models::File; +use serde::{Deserialize, Serialize}; + +/// Represents a complete package source tree. +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct Package { + /// Type of AST node + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// Package import path + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, + /// Package name + #[serde(skip_serializing_if = "Option::is_none")] + pub package: Option, + /// Package files + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub files: Vec, +} + +impl Package { + /// Represents a complete package source tree. + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/ast/package_clause.rs b/influxdb2_client/src/models/ast/package_clause.rs new file mode 100644 index 0000000..ec90473 --- /dev/null +++ b/influxdb2_client/src/models/ast/package_clause.rs @@ -0,0 +1,21 @@ +//! 
PackageClause + +use serde::{Deserialize, Serialize}; + +/// Defines a package identifier +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct PackageClause { + /// Type of AST node + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// Package name + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +impl PackageClause { + /// Defines a package identifier + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/ast/property.rs b/influxdb2_client/src/models/ast/property.rs new file mode 100644 index 0000000..ee5efd2 --- /dev/null +++ b/influxdb2_client/src/models/ast/property.rs @@ -0,0 +1,24 @@ +//! Property + +use serde::{Deserialize, Serialize}; + +/// The value associated with a key +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct Property { + /// Type of AST node + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// Property Key + #[serde(skip_serializing_if = "Option::is_none")] + pub key: Option, + /// Property Value + #[serde(skip_serializing_if = "Option::is_none")] + pub value: Option, +} + +impl Property { + /// The value associated with a key + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/ast/property_key.rs b/influxdb2_client/src/models/ast/property_key.rs new file mode 100644 index 0000000..71f521c --- /dev/null +++ b/influxdb2_client/src/models/ast/property_key.rs @@ -0,0 +1,24 @@ +//! 
PropertyKey + +use serde::{Deserialize, Serialize}; + +/// Key value pair +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct PropertyKey { + /// Type of AST node + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// PropertyKey name + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + /// PropertyKey value + #[serde(skip_serializing_if = "Option::is_none")] + pub value: Option, +} + +impl PropertyKey { + /// Returns an instance of PropertyKey + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/ast/statement.rs b/influxdb2_client/src/models/ast/statement.rs new file mode 100644 index 0000000..edbc526 --- /dev/null +++ b/influxdb2_client/src/models/ast/statement.rs @@ -0,0 +1,39 @@ +//! Statement + +use serde::{Deserialize, Serialize}; + +/// Expression Statement +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct Statement { + /// Type of AST node + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// Raw source text + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + /// Statement identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + /// Initial Value + #[serde(skip_serializing_if = "Option::is_none")] + pub init: Option, + /// Member + #[serde(skip_serializing_if = "Option::is_none")] + pub member: Option, + /// Expression + #[serde(skip_serializing_if = "Option::is_none")] + pub expression: Option, + /// Argument + #[serde(skip_serializing_if = "Option::is_none")] + pub argument: Option, + /// Assignment + #[serde(skip_serializing_if = "Option::is_none")] + pub assignment: Option, +} + +impl Statement { + /// Returns an instance of Statement + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/ast/string_literal.rs b/influxdb2_client/src/models/ast/string_literal.rs
new file mode 100644 index 0000000..e5c144e --- /dev/null +++ b/influxdb2_client/src/models/ast/string_literal.rs @@ -0,0 +1,21 @@ +//! StringLiteral + +use serde::{Deserialize, Serialize}; + +/// Expressions begin and end with double quote marks +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct StringLiteral { + /// Type of AST node + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// StringLiteral Value + #[serde(skip_serializing_if = "Option::is_none")] + pub value: Option, +} + +impl StringLiteral { + /// Expressions begin and end with double quote marks + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/ast/variable_assignment.rs b/influxdb2_client/src/models/ast/variable_assignment.rs new file mode 100644 index 0000000..f9f6111 --- /dev/null +++ b/influxdb2_client/src/models/ast/variable_assignment.rs @@ -0,0 +1,24 @@ +//! VariableAssignment + +use serde::{Deserialize, Serialize}; + +/// Represents the declaration of a variable +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct VariableAssignment { + /// Type of AST node + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// Variable Identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + /// Variable initial value + #[serde(skip_serializing_if = "Option::is_none")] + pub init: Option, +} + +impl VariableAssignment { + /// Represents the declaration of a variable + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/authorization.rs b/influxdb2_client/src/models/authorization.rs new file mode 100644 index 0000000..fe4878d --- /dev/null +++ b/influxdb2_client/src/models/authorization.rs @@ -0,0 +1,88 @@ +//! Authorization +//! +//! 
Auth tokens for InfluxDB + +use serde::{Deserialize, Serialize}; + +/// Authorization to create +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Authorization { + /// If inactive the token is inactive and requests using the token will be + /// rejected. + #[serde(skip_serializing_if = "Option::is_none")] + pub status: Option, + /// A description of the token. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Auth created_at + #[serde(skip_serializing_if = "Option::is_none")] + pub created_at: Option, + /// Auth updated_at + #[serde(skip_serializing_if = "Option::is_none")] + pub updated_at: Option, + /// ID of org that authorization is scoped to. + #[serde(rename = "orgID")] + pub org_id: String, + /// List of permissions for an auth. An auth must have at least one + /// Permission. + pub permissions: Vec, + /// Auth ID. + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + /// Passed via the Authorization Header and Token Authentication type. + #[serde(skip_serializing_if = "Option::is_none")] + pub token: Option, + /// ID of user that created and owns the token. + #[serde(rename = "userID", skip_serializing_if = "Option::is_none")] + pub user_id: Option, + /// Name of user that created and owns the token. + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + /// Name of the org token is scoped to. + #[serde(skip_serializing_if = "Option::is_none")] + pub org: Option, + /// Links + #[serde(skip_serializing_if = "Option::is_none")] + pub links: Option, +} + +impl Authorization { + /// Returns an Authorization with the given orgID and permissions + pub fn new(org_id: String, permissions: Vec) -> Self { + Self { + org_id, + permissions, + ..Default::default() + } + } +} + +/// If inactive the token is inactive and requests using the token will be +/// rejected. 
+#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum Status { + /// Token is active. + Active, + /// Token is inactive. + Inactive, +} + +/// AuthorizationAllOfLinks +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct AuthorizationAllOfLinks { + /// Self + #[serde(rename = "self", skip_serializing_if = "Option::is_none")] + pub self_: Option, + /// User + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + +impl AuthorizationAllOfLinks { + /// Return an instance of AuthorizationAllOfLinks + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/bucket.rs b/influxdb2_client/src/models/bucket.rs new file mode 100644 index 0000000..432b785 --- /dev/null +++ b/influxdb2_client/src/models/bucket.rs @@ -0,0 +1,142 @@ +//! Bucket + +use serde::{Deserialize, Serialize}; + +/// Bucket Schema +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Bucket { + /// BucketLinks + #[serde(skip_serializing_if = "Option::is_none")] + pub links: Option, + /// Bucket ID + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + /// Bucket Type + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// Bucket name + pub name: String, + /// Bucket description + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Organization ID of bucket + #[serde(rename = "orgID", skip_serializing_if = "Option::is_none")] + pub org_id: Option, + /// RP + #[serde(skip_serializing_if = "Option::is_none")] + pub rp: Option, + /// Created At + #[serde(skip_serializing_if = "Option::is_none")] + pub created_at: Option, + /// Updated At + #[serde(skip_serializing_if = "Option::is_none")] + pub updated_at: Option, + /// Rules to expire or retain data. No rules means data never expires. 
+ pub retention_rules: Vec, + /// Bucket labels + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub labels: Vec, +} + +impl Bucket { + /// Returns instance of Bucket + pub fn new(name: String, retention_rules: Vec) -> Self { + Self { + name, + retention_rules, + ..Default::default() + } + } +} + +/// Bucket Type +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum Type { + /// User + User, + /// System + System, +} + +/// Bucket links +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct BucketLinks { + /// Labels + #[serde(skip_serializing_if = "Option::is_none")] + pub labels: Option, + /// Members + #[serde(skip_serializing_if = "Option::is_none")] + pub members: Option, + /// Organization + #[serde(skip_serializing_if = "Option::is_none")] + pub org: Option, + /// Owners + #[serde(skip_serializing_if = "Option::is_none")] + pub owners: Option, + /// Self + #[serde(rename = "self", skip_serializing_if = "Option::is_none")] + pub self_: Option, + /// Write + #[serde(skip_serializing_if = "Option::is_none")] + pub write: Option, +} + +impl BucketLinks { + /// Returns instance of BucketLinks + pub fn new() -> Self { + Self::default() + } +} + +/// List all buckets +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Buckets { + /// Links + #[serde(skip_serializing_if = "Option::is_none")] + pub links: Option, + /// Buckets + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub buckets: Vec, +} + +impl Buckets { + /// Returns list of buckets + pub fn new() -> Self { + Self::default() + } +} + +/// PostBucketRequest +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +pub struct PostBucketRequest { + /// Organization ID + #[serde(rename = "orgID")] + pub org_id: String, + 
/// Bucket name + pub name: String, + /// Bucket Description + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// RP + #[serde(skip_serializing_if = "Option::is_none")] + pub rp: Option, + /// Rules to expire or retain data. No rules means data never expires. + #[serde(default)] + pub retention_rules: Vec, +} + +impl PostBucketRequest { + /// Returns instance of PostBucketRequest + pub fn new(org_id: String, name: String) -> Self { + Self { + org_id, + name, + ..Default::default() + } + } +} diff --git a/influxdb2_client/src/models/data_point.rs b/influxdb2_client/src/models/data_point.rs new file mode 100644 index 0000000..ff4877a --- /dev/null +++ b/influxdb2_client/src/models/data_point.rs @@ -0,0 +1,511 @@ +//! Data point building and writing + +use snafu::{ensure, Snafu}; +use std::{collections::BTreeMap, io}; + +/// Errors that occur while building `DataPoint`s +#[derive(Debug, Snafu)] +pub enum DataPointError { + /// Returned when calling `build` on a `DataPointBuilder` that has no + /// fields. + #[snafu(display( + "All `DataPoints` must have at least one field. Builder contains: {:?}", + data_point_builder + ))] + AtLeastOneFieldRequired { + /// The current state of the `DataPointBuilder` + data_point_builder: DataPointBuilder, + }, +} + +/// Incrementally constructs a `DataPoint`. +/// +/// Create this via `DataPoint::builder`. +#[derive(Debug)] +pub struct DataPointBuilder { + measurement: String, + // Keeping the tags sorted improves performance on the server side + tags: BTreeMap, + fields: BTreeMap, + timestamp: Option, +} + +impl DataPointBuilder { + fn new(measurement: impl Into) -> Self { + Self { + measurement: measurement.into(), + tags: Default::default(), + fields: Default::default(), + timestamp: Default::default(), + } + } + + /// Sets a tag, replacing any existing tag of the same name. 
+ pub fn tag(mut self, name: impl Into, value: impl Into) -> Self { + self.tags.insert(name.into(), value.into()); + self + } + + /// Sets a field, replacing any existing field of the same name. + pub fn field(mut self, name: impl Into, value: impl Into) -> Self { + self.fields.insert(name.into(), value.into()); + self + } + + /// Sets the timestamp, replacing any existing timestamp. + /// + /// The value is treated as the number of nanoseconds since the + /// UNIX epoch. + pub fn timestamp(mut self, value: i64) -> Self { + self.timestamp = Some(value); + self + } + + /// Constructs the data point + pub fn build(self) -> Result { + ensure!( + !self.fields.is_empty(), + AtLeastOneFieldRequiredSnafu { + data_point_builder: self + } + ); + + let Self { + measurement, + tags, + fields, + timestamp, + } = self; + + Ok(DataPoint { + measurement, + tags, + fields, + timestamp, + }) + } +} + +/// A single point of information to send to InfluxDB. +// TODO: If we want to support non-UTF-8 data, all `String`s stored in `DataPoint` would need +// to be `Vec` instead, the API for creating a `DataPoint` would need some more consideration, +// and there would need to be more `Write*` trait implementations. Because the `Write*` traits work +// on a writer of bytes, that part of the design supports non-UTF-8 data now. +#[derive(Debug)] +pub struct DataPoint { + measurement: String, + tags: BTreeMap, + fields: BTreeMap, + timestamp: Option, +} + +impl DataPoint { + /// Create a builder to incrementally construct a `DataPoint`. 
+ pub fn builder(measurement: impl Into) -> DataPointBuilder { + DataPointBuilder::new(measurement) + } +} + +impl WriteDataPoint for DataPoint { + fn write_data_point_to(&self, mut w: W) -> io::Result<()> + where + W: io::Write, + { + self.measurement.write_measurement_to(&mut w)?; + + for (k, v) in &self.tags { + w.write_all(b",")?; + k.write_tag_key_to(&mut w)?; + w.write_all(b"=")?; + v.write_tag_value_to(&mut w)?; + } + + for (i, (k, v)) in self.fields.iter().enumerate() { + let d = if i == 0 { b" " } else { b"," }; + + w.write_all(d)?; + k.write_field_key_to(&mut w)?; + w.write_all(b"=")?; + v.write_field_value_to(&mut w)?; + } + + if let Some(ts) = self.timestamp { + w.write_all(b" ")?; + ts.write_timestamp_to(&mut w)?; + } + + w.write_all(b"\n")?; + + Ok(()) + } +} + +/// Possible value types +#[derive(Debug, Clone, PartialEq)] +pub enum FieldValue { + /// A true or false value + Bool(bool), + /// A 64-bit floating point number + F64(f64), + /// A 64-bit signed integer number + I64(i64), + /// A 64-bit unsigned integer number + U64(u64), + /// A string value + String(String), +} + +impl From for FieldValue { + fn from(other: bool) -> Self { + Self::Bool(other) + } +} + +impl From for FieldValue { + fn from(other: f64) -> Self { + Self::F64(other) + } +} + +impl From for FieldValue { + fn from(other: i64) -> Self { + Self::I64(other) + } +} + +impl From for FieldValue { + fn from(other: u64) -> Self { + Self::U64(other) + } +} + +impl From<&str> for FieldValue { + fn from(other: &str) -> Self { + Self::String(other.into()) + } +} + +impl From for FieldValue { + fn from(other: String) -> Self { + Self::String(other) + } +} + +/// Transform a type into valid line protocol lines +/// +/// This trait is to enable the conversion of `DataPoint`s to line protocol; it +/// is unlikely that you would need to implement this trait. 
In the future, a +/// `derive` crate may exist that would facilitate the generation of +/// implementations of this trait on custom types to help uphold the +/// responsibilities for escaping and producing complete lines. +pub trait WriteDataPoint { + /// Write this data point as line protocol. The implementor is responsible + /// for properly escaping the data and ensuring that complete lines + /// are generated. + fn write_data_point_to(&self, w: W) -> io::Result<()> + where + W: io::Write; +} + +// The following are traits rather than free functions so that we can limit +// their implementations to only the data types supported for each of +// measurement, tag key, tag value, field key, field value, and timestamp. They +// are a private implementation detail and any custom implementations +// of these traits would be generated by a future derive trait. +trait WriteMeasurement { + fn write_measurement_to(&self, w: W) -> io::Result<()> + where + W: io::Write; +} + +impl WriteMeasurement for str { + fn write_measurement_to(&self, w: W) -> io::Result<()> + where + W: io::Write, + { + escape_and_write_value(self, MEASUREMENT_DELIMITERS, w) + } +} + +trait WriteTagKey { + fn write_tag_key_to(&self, w: W) -> io::Result<()> + where + W: io::Write; +} + +impl WriteTagKey for str { + fn write_tag_key_to(&self, w: W) -> io::Result<()> + where + W: io::Write, + { + escape_and_write_value(self, TAG_KEY_DELIMITERS, w) + } +} + +trait WriteTagValue { + fn write_tag_value_to(&self, w: W) -> io::Result<()> + where + W: io::Write; +} + +impl WriteTagValue for str { + fn write_tag_value_to(&self, w: W) -> io::Result<()> + where + W: io::Write, + { + escape_and_write_value(self, TAG_VALUE_DELIMITERS, w) + } +} + +trait WriteFieldKey { + fn write_field_key_to(&self, w: W) -> io::Result<()> + where + W: io::Write; +} + +impl WriteFieldKey for str { + fn write_field_key_to(&self, w: W) -> io::Result<()> + where + W: io::Write, + { + escape_and_write_value(self, FIELD_KEY_DELIMITERS, 
w) + } +} + +trait WriteFieldValue { + fn write_field_value_to(&self, w: W) -> io::Result<()> + where + W: io::Write; +} + +impl WriteFieldValue for FieldValue { + fn write_field_value_to(&self, mut w: W) -> io::Result<()> + where + W: io::Write, + { + use FieldValue::*; + + match self { + Bool(v) => write!(w, "{}", if *v { "t" } else { "f" }), + F64(v) => write!(w, "{v}"), + I64(v) => write!(w, "{v}i"), + U64(v) => write!(w, "{v}u"), + String(v) => { + w.write_all(br#"""#)?; + escape_and_write_value(v, FIELD_VALUE_STRING_DELIMITERS, &mut w)?; + w.write_all(br#"""#) + } + } + } +} + +trait WriteTimestamp { + fn write_timestamp_to(&self, w: W) -> io::Result<()> + where + W: io::Write; +} + +impl WriteTimestamp for i64 { + fn write_timestamp_to(&self, mut w: W) -> io::Result<()> + where + W: io::Write, + { + write!(w, "{self}") + } +} + +const MEASUREMENT_DELIMITERS: &[char] = &[',', ' ']; +const TAG_KEY_DELIMITERS: &[char] = &[',', '=', ' ']; +const TAG_VALUE_DELIMITERS: &[char] = TAG_KEY_DELIMITERS; +const FIELD_KEY_DELIMITERS: &[char] = TAG_KEY_DELIMITERS; +const FIELD_VALUE_STRING_DELIMITERS: &[char] = &['"']; + +fn escape_and_write_value( + value: &str, + escaping_specification: &[char], + mut w: W, +) -> io::Result<()> +where + W: io::Write, +{ + let mut last = 0; + + for (idx, delim) in value.match_indices(escaping_specification) { + let s = &value[last..idx]; + write!(w, r#"{s}\{delim}"#)?; + last = idx + delim.len(); + } + + w.write_all(value[last..].as_bytes()) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::str; + + fn assert_utf8_strings_eq(left: &[u8], right: &[u8]) { + assert_eq!( + left, + right, + "\n\nleft string value: `{}`,\nright string value: `{}`", + str::from_utf8(left).unwrap(), + str::from_utf8(right).unwrap(), + ); + } + + #[test] + fn point_builder_allows_setting_tags_and_fields() { + let point = DataPoint::builder("swap") + .tag("host", "server01") + .tag("name", "disk0") + .field("in", 3_i64) + .field("out", 4_i64) + 
.timestamp(1) + .build() + .unwrap(); + + assert_utf8_strings_eq( + &point.data_point_to_vec().unwrap(), + b"swap,host=server01,name=disk0 in=3i,out=4i 1\n".as_ref(), + ); + } + + #[test] + fn no_tags_or_timestamp() { + let point = DataPoint::builder("m0") + .field("f0", 1.0) + .field("f1", 2_i64) + .build() + .unwrap(); + + assert_utf8_strings_eq( + &point.data_point_to_vec().unwrap(), + b"m0 f0=1,f1=2i\n".as_ref(), + ); + } + + #[test] + fn no_timestamp() { + let point = DataPoint::builder("m0") + .tag("t0", "v0") + .tag("t1", "v1") + .field("f1", 2_i64) + .build() + .unwrap(); + + assert_utf8_strings_eq( + &point.data_point_to_vec().unwrap(), + b"m0,t0=v0,t1=v1 f1=2i\n".as_ref(), + ); + } + + #[test] + fn no_field() { + let point_result = DataPoint::builder("m0").build(); + + assert!(point_result.is_err()); + } + + const ALL_THE_DELIMITERS: &str = r#"alpha,beta=delta gamma"epsilon"#; + + #[test] + fn special_characters_are_escaped_in_measurements() { + assert_utf8_strings_eq( + &ALL_THE_DELIMITERS.measurement_to_vec().unwrap(), + br#"alpha\,beta=delta\ gamma"epsilon"#.as_ref(), + ); + } + + #[test] + fn special_characters_are_escaped_in_tag_keys() { + assert_utf8_strings_eq( + &ALL_THE_DELIMITERS.tag_key_to_vec().unwrap(), + br#"alpha\,beta\=delta\ gamma"epsilon"#.as_ref(), + ); + } + + #[test] + fn special_characters_are_escaped_in_tag_values() { + assert_utf8_strings_eq( + &ALL_THE_DELIMITERS.tag_value_to_vec().unwrap(), + br#"alpha\,beta\=delta\ gamma"epsilon"#.as_ref(), + ); + } + + #[test] + fn special_characters_are_escaped_in_field_keys() { + assert_utf8_strings_eq( + &ALL_THE_DELIMITERS.field_key_to_vec().unwrap(), + br#"alpha\,beta\=delta\ gamma"epsilon"#.as_ref(), + ); + } + + #[test] + fn special_characters_are_escaped_in_field_values_of_strings() { + assert_utf8_strings_eq( + &FieldValue::from(ALL_THE_DELIMITERS) + .field_value_to_vec() + .unwrap(), + br#""alpha,beta=delta gamma\"epsilon""#.as_ref(), + ); + } + + #[test] + fn field_value_of_bool() { 
+ let e = FieldValue::from(true); + assert_utf8_strings_eq(&e.field_value_to_vec().unwrap(), b"t"); + + let e = FieldValue::from(false); + assert_utf8_strings_eq(&e.field_value_to_vec().unwrap(), b"f"); + } + + #[test] + fn field_value_of_float() { + let e = FieldValue::from(42_f64); + assert_utf8_strings_eq(&e.field_value_to_vec().unwrap(), b"42"); + } + + #[test] + fn field_value_of_signed_integer() { + let e = FieldValue::from(42_i64); + assert_utf8_strings_eq(&e.field_value_to_vec().unwrap(), b"42i"); + } + + #[test] + fn field_value_of_unsigned_integer() { + let e = FieldValue::from(42_u64); + assert_utf8_strings_eq(&e.field_value_to_vec().unwrap(), b"42u"); + } + + #[test] + fn field_value_of_string() { + let e = FieldValue::from("hello"); + assert_utf8_strings_eq(&e.field_value_to_vec().unwrap(), br#""hello""#); + } + + // Clears up the boilerplate of writing to a vector from the tests + macro_rules! test_extension_traits { + ($($ext_name:ident :: $ext_fn_name:ident -> $base_name:ident :: $base_fn_name:ident,)*) => { + $( + trait $ext_name: $base_name { + fn $ext_fn_name(&self) -> io::Result> { + let mut v = Vec::new(); + self.$base_fn_name(&mut v)?; + Ok(v) + } + } + impl $ext_name for T {} + )* + } + } + + test_extension_traits! { + WriteDataPointExt::data_point_to_vec -> WriteDataPoint::write_data_point_to, + WriteMeasurementExt::measurement_to_vec -> WriteMeasurement::write_measurement_to, + WriteTagKeyExt::tag_key_to_vec -> WriteTagKey::write_tag_key_to, + WriteTagValueExt::tag_value_to_vec -> WriteTagValue::write_tag_value_to, + WriteFieldKeyExt::field_key_to_vec -> WriteFieldKey::write_field_key_to, + WriteFieldValueExt::field_value_to_vec -> WriteFieldValue::write_field_value_to, + } +} diff --git a/influxdb2_client/src/models/file.rs b/influxdb2_client/src/models/file.rs new file mode 100644 index 0000000..d0f0ab8 --- /dev/null +++ b/influxdb2_client/src/models/file.rs @@ -0,0 +1,30 @@ +//! 
File + +use serde::{Deserialize, Serialize}; + +/// Represents a source from a single file +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct File { + /// Type of AST node + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// The name of the file. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + /// PackageClause + #[serde(skip_serializing_if = "Option::is_none")] + pub package: Option, + /// A list of package imports + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub imports: Vec, + /// List of Flux statements + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub body: Vec, +} + +impl File { + /// Represents a source from a single file + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/health.rs b/influxdb2_client/src/models/health.rs new file mode 100644 index 0000000..b9f46eb --- /dev/null +++ b/influxdb2_client/src/models/health.rs @@ -0,0 +1,49 @@ +//! 
Health + +use serde::{Deserialize, Serialize}; + +/// HealthCheck +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct HealthCheck { + /// Name of the influxdb instance + pub name: String, + /// Message + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, + /// Checks + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub checks: Vec, + /// Status + pub status: Status, + /// Version + #[serde(skip_serializing_if = "Option::is_none")] + pub version: Option, + /// Commit + #[serde(skip_serializing_if = "Option::is_none")] + pub commit: Option, +} + +impl HealthCheck { + /// Returns instance of HealthCheck + pub fn new(name: String, status: Status) -> Self { + Self { + name, + status, + message: None, + checks: Vec::new(), + version: None, + commit: None, + } + } +} + +/// Status +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum Status { + /// Pass + Pass, + /// Fail + Fail, +} diff --git a/influxdb2_client/src/models/label.rs b/influxdb2_client/src/models/label.rs new file mode 100644 index 0000000..de11ae4 --- /dev/null +++ b/influxdb2_client/src/models/label.rs @@ -0,0 +1,111 @@ +//! Labels + +use serde::{Deserialize, Serialize}; + +/// Post create label request, to create a new label +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LabelCreateRequest { + /// Organisation ID + #[serde(rename = "orgID")] + pub org_id: String, + /// Label name + pub name: String, + /// Key/Value pairs associated with this label. 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub properties: Option<::std::collections::HashMap>, +} + +impl LabelCreateRequest { + /// Return instance of LabelCreateRequest + pub fn new(org_id: String, name: String) -> Self { + Self { + org_id, + name, + ..Default::default() + } + } +} + +/// LabelResponse +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LabelResponse { + /// Label + #[serde(skip_serializing_if = "Option::is_none")] + pub label: Option, + /// Links + #[serde(skip_serializing_if = "Option::is_none")] + pub links: Option, +} + +impl LabelResponse { + /// Returns instance of LabelResponse + pub fn new() -> Self { + Self::default() + } +} + +///LabelsResponse +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LabelsResponse { + /// Labels + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub labels: Vec, + /// Links + #[serde(skip_serializing_if = "Option::is_none")] + pub links: Option, +} + +impl LabelsResponse { + /// Returns List of Labels + pub fn new() -> Self { + Self::default() + } +} + +///LabelUpdateRequest +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LabelUpdate { + /// Name + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + /// Key/Value pairs associated with this label. 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub properties: Option<::std::collections::HashMap>, +} + +impl LabelUpdate { + /// Returns an instance of LabelUpdate + pub fn new() -> Self { + Self::default() + } +} + +/// Label +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Label { + /// Label ID + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + /// Org ID + #[serde(rename = "orgID", skip_serializing_if = "Option::is_none")] + pub org_id: Option, + /// Label name + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + /// Key/Value pairs associated with this label. Keys can be removed by + /// sending an update with an empty value. + #[serde(skip_serializing_if = "Option::is_none")] + pub properties: Option<::std::collections::HashMap>, +} + +impl Label { + /// Returns an instance of Label + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/links.rs b/influxdb2_client/src/models/links.rs new file mode 100644 index 0000000..6bf5871 --- /dev/null +++ b/influxdb2_client/src/models/links.rs @@ -0,0 +1,28 @@ +//! Links + +use serde::{Deserialize, Serialize}; + +/// Links +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Links { + /// Next link + #[serde(skip_serializing_if = "Option::is_none")] + pub next: Option, + /// Link to self + #[serde(rename = "self")] + pub self_: String, + /// Previous Link + #[serde(skip_serializing_if = "Option::is_none")] + pub prev: Option, +} + +impl Links { + /// Returns list of Links + pub fn new(self_: String) -> Self { + Self { + self_, + ..Default::default() + } + } +} diff --git a/influxdb2_client/src/models/mod.rs b/influxdb2_client/src/models/mod.rs new file mode 100644 index 0000000..8c5b4bc --- /dev/null +++ b/influxdb2_client/src/models/mod.rs @@ -0,0 +1,37 @@ +//! InfluxDB Models +//! +//! 
Roughly follows the OpenAPI specification + +pub mod ast; + +pub mod user; +pub use self::user::{User, UserLinks, Users, UsersLinks}; +pub mod organization; +pub use self::organization::{Organization, OrganizationLinks, Organizations}; +pub mod bucket; +pub use self::bucket::{Bucket, BucketLinks, Buckets, PostBucketRequest}; +pub mod onboarding; +pub use self::onboarding::{IsOnboarding, OnboardingRequest, OnboardingResponse}; +pub mod links; +pub use self::links::Links; +pub mod permission; +pub use self::permission::Permission; +pub mod label; +pub use self::label::{Label, LabelCreateRequest, LabelResponse, LabelUpdate, LabelsResponse}; +pub mod authorization; +pub use self::authorization::{Authorization, AuthorizationAllOfLinks}; +pub mod resource; +pub use self::resource::Resource; +pub mod retention_rule; +pub use self::retention_rule::RetentionRule; +pub mod query; +pub use self::query::{ + AnalyzeQueryResponse, AnalyzeQueryResponseErrors, AstResponse, FluxSuggestion, FluxSuggestions, + LanguageRequest, Query, +}; +pub mod file; +pub use self::file::File; +pub mod health; +pub use self::health::{HealthCheck, Status}; +pub mod data_point; +pub use data_point::{DataPoint, FieldValue, WriteDataPoint}; diff --git a/influxdb2_client/src/models/onboarding.rs b/influxdb2_client/src/models/onboarding.rs new file mode 100644 index 0000000..9c720c9 --- /dev/null +++ b/influxdb2_client/src/models/onboarding.rs @@ -0,0 +1,80 @@ +//! # Onboarding +//! +//! Initial setup of InfluxDB instance + +use serde::{Deserialize, Serialize}; + +/// Check if database has default user, org, bucket created, returns true if +/// not. 
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct IsOnboarding { + /// True if onboarding has already been completed otherwise false + #[serde(default)] + pub allowed: bool, +} + +impl IsOnboarding { + /// Return instance of IsOnboarding + pub fn new() -> Self { + Self::default() + } +} + +/// Post onboarding request, to setup initial user, org and bucket. +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct OnboardingRequest { + /// Initial username + pub username: String, + /// Initial organization name + pub org: String, + /// Initial bucket name + pub bucket: String, + /// Initial password of user + #[serde(skip_serializing_if = "Option::is_none")] + pub password: Option, + /// Retention period in nanoseconds + #[serde(skip_serializing_if = "Option::is_none")] + pub retention_period_seconds: Option, + /// Retention period *in nanoseconds* for the new bucket. This key's name + /// has been misleading since OSS 2.0 GA, please transition to use + /// `retentionPeriodSeconds` + #[serde(skip_serializing_if = "Option::is_none")] + pub retention_period_hrs: Option, +} + +impl OnboardingRequest { + /// Return instance of OnboardingRequest + pub fn new(username: String, org: String, bucket: String) -> Self { + Self { + username, + org, + bucket, + ..Default::default() + } + } +} + +/// OnboardingResponse +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct OnboardingResponse { + /// User + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + /// Organization + #[serde(skip_serializing_if = "Option::is_none")] + pub org: Option, + /// Bucket + #[serde(skip_serializing_if = "Option::is_none")] + pub bucket: Option, + /// Auth token + #[serde(skip_serializing_if = "Option::is_none")] + pub auth: Option, +} + +impl OnboardingResponse { + /// Return instance of OnboardingResponse + pub fn new() -> Self { + 
Self::default() + } +} diff --git a/influxdb2_client/src/models/organization.rs b/influxdb2_client/src/models/organization.rs new file mode 100644 index 0000000..96c81ff --- /dev/null +++ b/influxdb2_client/src/models/organization.rs @@ -0,0 +1,103 @@ +//! Organization + +use serde::{Deserialize, Serialize}; + +/// Organization Schema +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Organization { + /// Links + #[serde(skip_serializing_if = "Option::is_none")] + pub links: Option, + /// Organization ID + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + /// Organization Name + pub name: String, + /// Organization description + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Organization created timestamp + #[serde(skip_serializing_if = "Option::is_none")] + pub created_at: Option, + /// Organization updated timestamp + #[serde(skip_serializing_if = "Option::is_none")] + pub updated_at: Option, + /// If inactive the organization is inactive. + #[serde(skip_serializing_if = "Option::is_none")] + pub status: Option, +} + +impl Organization { + /// Returns instance of Organization + pub fn new(name: String) -> Self { + Self { + name, + ..Default::default() + } + } +} + +/// If inactive the organization is inactive. 
+#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum Status { + /// Organization is active + Active, + /// Organization is inactive + Inactive, +} + +/// Organization Links +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct OrganizationLinks { + /// Link to self + #[serde(rename = "self", skip_serializing_if = "Option::is_none")] + pub self_: Option, + /// Links to members + #[serde(skip_serializing_if = "Option::is_none")] + pub members: Option, + /// Links to owners + #[serde(skip_serializing_if = "Option::is_none")] + pub owners: Option, + /// Links to labels + #[serde(skip_serializing_if = "Option::is_none")] + pub labels: Option, + /// Links to secrets + #[serde(skip_serializing_if = "Option::is_none")] + pub secrets: Option, + /// Links to buckets + #[serde(skip_serializing_if = "Option::is_none")] + pub buckets: Option, + /// Links to tasks + #[serde(skip_serializing_if = "Option::is_none")] + pub tasks: Option, + /// Links to dashboards + #[serde(skip_serializing_if = "Option::is_none")] + pub dashboards: Option, +} + +impl OrganizationLinks { + /// Returns instance of Organization Links + pub fn new() -> Self { + Self::default() + } +} + +/// Organizations +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct Organizations { + /// Links + #[serde(skip_serializing_if = "Option::is_none")] + pub links: Option, + /// List of organizations + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub orgs: Vec, +} + +impl Organizations { + /// Returns instance of Organizations + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/src/models/permission.rs b/influxdb2_client/src/models/permission.rs new file mode 100644 index 0000000..7791d1b --- /dev/null +++ b/influxdb2_client/src/models/permission.rs @@ -0,0 +1,29 @@ +//! 
Permissions + +use serde::{Deserialize, Serialize}; + +/// Permissions for a resource +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct Permission { + /// Access Type + pub action: Action, + /// Resource object + pub resource: crate::models::Resource, +} + +impl Permission { + /// Return instance of Permission + pub fn new(action: Action, resource: crate::models::Resource) -> Self { + Self { action, resource } + } +} + +/// Allowed Permission Action +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum Action { + /// Read access + Read, + /// Write access + Write, +} diff --git a/influxdb2_client/src/models/query.rs b/influxdb2_client/src/models/query.rs new file mode 100644 index 0000000..cc5c7a2 --- /dev/null +++ b/influxdb2_client/src/models/query.rs @@ -0,0 +1,161 @@ +//! Query + +use crate::models::ast::Package; +use crate::models::File; +use serde::{Deserialize, Serialize}; +use serde_json::Number; +use std::collections::HashMap; + +/// Query influx using the Flux language +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct Query { + /// Query Script + #[serde(rename = "extern", skip_serializing_if = "Option::is_none")] + pub r#extern: Option, + /// Query script to execute. + pub query: String, + /// The type of query. Must be \"flux\". + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, + /// Dialect + #[serde(skip_serializing_if = "Option::is_none")] + pub dialect: Option, + /// Specifies the time that should be reported as "now" in the query. + /// Default is the server's now time. 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub now: Option, + + /// Params for use in query via params.param_name + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option>, +} + +/// Query Param Enum for Flux +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[serde(untagged)] +pub enum Param { + /// A number param + Number(Number), + /// A string param + String(String), +} + +impl Query { + /// Query influx using the Flux language + pub fn new(query: String) -> Self { + Self { + query, + ..Default::default() + } + } +} + +/// The type of query. Must be \"flux\". +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum Type { + /// Query Type + Flux, +} + +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +/// Flux Query Suggestion +pub struct FluxSuggestion { + /// Suggestion Name + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + /// Suggestion Params + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option>, +} + +impl FluxSuggestion { + /// Returns an instance FluxSuggestion + pub fn new() -> Self { + Self::default() + } +} + +/// FluxSuggestions +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct FluxSuggestions { + /// List of Flux Suggestions + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub funcs: Vec, +} + +impl FluxSuggestions { + /// Return an instance of FluxSuggestions + pub fn new() -> Self { + Self::default() + } +} + +/// AnalyzeQueryResponse +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct AnalyzeQueryResponse { + /// List of QueryResponseErrors + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub errors: Vec, +} + +impl AnalyzeQueryResponse { + /// Return an instance of AnanlyzeQueryResponse + pub fn new() -> Self { + Self::default() + 
} +} + +/// AnalyzeQueryResponseErrors +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct AnalyzeQueryResponseErrors { + /// Error line + #[serde(skip_serializing_if = "Option::is_none")] + pub line: Option, + /// Error column + #[serde(skip_serializing_if = "Option::is_none")] + pub column: Option, + /// Error char + #[serde(skip_serializing_if = "Option::is_none")] + pub character: Option, + /// Error message + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, +} + +impl AnalyzeQueryResponseErrors { + /// Return an instance of AnalyzeQueryResponseErrors + pub fn new() -> Self { + Self::default() + } +} + +/// AstResponse : Contains the AST for the supplied Flux query +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct AstResponse { + /// AST of Flux query + #[serde(skip_serializing_if = "Option::is_none")] + pub ast: Option, +} + +impl AstResponse { + /// Contains the AST for the supplied Flux query + pub fn new() -> Self { + Self::default() + } +} + +/// LanguageRequest : Flux query to be analyzed. +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct LanguageRequest { + /// Flux query script to be analyzed + pub query: String, +} + +impl LanguageRequest { + /// Flux query to be analyzed. + pub fn new(query: String) -> Self { + Self { query } + } +} diff --git a/influxdb2_client/src/models/resource.rs b/influxdb2_client/src/models/resource.rs new file mode 100644 index 0000000..1b1e0e6 --- /dev/null +++ b/influxdb2_client/src/models/resource.rs @@ -0,0 +1,82 @@ +//! Resources + +use serde::{Deserialize, Serialize}; + +/// Construct a resource +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Resource { + /// Resource Type + #[serde(rename = "type")] + pub r#type: Type, + /// If ID is set that is a permission for a specific resource. 
if it is not + /// set it is a permission for all resources of that resource type. + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + /// Optional name of the resource if the resource has a name field. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + /// If orgID is set that is a permission for all resources owned my that + /// org. if it is not set it is a permission for all resources of that + /// resource type. + #[serde(rename = "orgID", skip_serializing_if = "Option::is_none")] + pub org_id: Option, + /// Optional name of the organization of the organization with orgID. + #[serde(skip_serializing_if = "Option::is_none")] + pub org: Option, +} + +impl Resource { + /// Returns instance of Resource + pub fn new(r#type: Type) -> Self { + Self { + r#type, + id: None, + name: None, + org_id: None, + org: None, + } + } +} + +/// Resource Type +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum Type { + /// Authorizations + Authorizations, + /// Buckets + Buckets, + /// Dashboards + Dashboards, + /// Organizations + Orgs, + /// Sources + Sources, + /// Tasks + Tasks, + /// Telegrafs + Telegrafs, + /// Users + Users, + /// Variables + Variables, + /// Scrapers + Scrapers, + /// Secrets + Secrets, + /// Labels + Labels, + /// Views + Views, + /// Documents + Documents, + /// Notification Rules + NotificationRules, + /// Notification Endpoints + NotificationEndpoints, + /// Checks + Checks, + /// DBRP + Dbrp, +} diff --git a/influxdb2_client/src/models/retention_rule.rs b/influxdb2_client/src/models/retention_rule.rs new file mode 100644 index 0000000..ee78609 --- /dev/null +++ b/influxdb2_client/src/models/retention_rule.rs @@ -0,0 +1,37 @@ +//! 
Retention Rules + +use serde::{Deserialize, Serialize}; + +/// RetentionRule +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RetentionRule { + /// Expiry + #[serde(rename = "type")] + pub r#type: Type, + /// Duration in seconds for how long data will be kept in the database. 0 + /// means infinite. + pub every_seconds: i32, + /// Shard duration measured in seconds. + #[serde(skip_serializing_if = "Option::is_none")] + pub shard_group_duration_seconds: Option, +} + +impl RetentionRule { + /// Returns instance of RetentionRule + pub fn new(r#type: Type, every_seconds: i32) -> Self { + Self { + r#type, + every_seconds, + shard_group_duration_seconds: None, + } + } +} + +/// Set Retention Rule expired or not +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum Type { + /// RetentionRule Expired + Expire, +} diff --git a/influxdb2_client/src/models/user.rs b/influxdb2_client/src/models/user.rs new file mode 100644 index 0000000..f0d8074 --- /dev/null +++ b/influxdb2_client/src/models/user.rs @@ -0,0 +1,90 @@ +//! Users + +use serde::{Deserialize, Serialize}; + +/// User Schema +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct User { + /// User ID + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + /// User oauth token id + #[serde(rename = "oauthID", skip_serializing_if = "Option::is_none")] + pub oauth_id: Option, + /// User name + pub name: String, + /// If inactive the user is inactive. + #[serde(skip_serializing_if = "Option::is_none")] + pub status: Option, + /// User links + #[serde(skip_serializing_if = "Option::is_none")] + pub links: Option, +} + +impl User { + /// Returns instance of user + pub fn new(name: String) -> Self { + Self { + name, + ..Default::default() + } + } +} + +/// If inactive the user is inactive. 
+#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum Status { + /// User is active + Active, + /// User is inactive + Inactive, +} + +/// User links +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct UserLinks { + /// User link to Self + #[serde(rename = "self", skip_serializing_if = "Option::is_none")] + pub self_: Option, +} + +impl UserLinks { + /// Returns instance of UserLinks + pub fn new() -> Self { + Self::default() + } +} + +/// List of Users +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct Users { + /// List of user links + #[serde(skip_serializing_if = "Option::is_none")] + pub links: Option, + /// List of users + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub users: Vec, +} + +impl Users { + /// Returns instance of Users + pub fn new() -> Self { + Self::default() + } +} + +/// UsersLinks +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct UsersLinks { + /// Users Link to Self + #[serde(rename = "self", skip_serializing_if = "Option::is_none")] + pub self_: Option, +} + +impl UsersLinks { + /// Returns instance of UsersLinks + pub fn new() -> Self { + Self::default() + } +} diff --git a/influxdb2_client/tests/common/mod.rs b/influxdb2_client/tests/common/mod.rs new file mode 100644 index 0000000..a057eb4 --- /dev/null +++ b/influxdb2_client/tests/common/mod.rs @@ -0,0 +1 @@ +pub mod server_fixture; diff --git a/influxdb2_client/tests/common/server_fixture.rs b/influxdb2_client/tests/common/server_fixture.rs new file mode 100644 index 0000000..e8707f2 --- /dev/null +++ b/influxdb2_client/tests/common/server_fixture.rs @@ -0,0 +1,358 @@ +use once_cell::sync::OnceCell; +use std::{ + fs::File, + process::{Child, Command, Stdio}, + sync::{ + atomic::{AtomicUsize, Ordering::SeqCst}, + Arc, Weak, + }, + time::Duration, +}; +use tokio::sync::Mutex; + 
+#[macro_export] +/// If `TEST_INTEGRATION` is set and InfluxDB 2.0 OSS is available (either locally +/// via `influxd` directly if the `INFLUXDB_IOX_INTEGRATION_LOCAL` environment +// variable is set, or via `docker` otherwise), set up the server as requested and +/// return it to the caller. +/// +/// If `TEST_INTEGRATION` is not set, skip the calling test by returning early. +macro_rules! maybe_skip_integration { + ($server_fixture:expr) => {{ + let local = std::env::var("INFLUXDB_IOX_INTEGRATION_LOCAL").is_ok(); + let command = if local { "influxd" } else { "docker" }; + + match ( + std::process::Command::new("which") + .arg(command) + .stdout(std::process::Stdio::null()) + .status() + .expect("should be able to run `which`") + .success(), + std::env::var("TEST_INTEGRATION").is_ok(), + ) { + (true, true) => $server_fixture, + (false, true) => { + panic!( + "TEST_INTEGRATION is set which requires running integration tests, but \ + `{}` is not available", + command + ) + } + _ => { + eprintln!( + "skipping integration test - set the TEST_INTEGRATION environment variable \ + and install `{}` to run", + command + ); + return Ok(()); + } + } + }}; +} + +/// Represents a server that has been started and is available for +/// testing. +#[derive(Debug)] +pub struct ServerFixture { + server: Arc, +} + +impl ServerFixture { + /// Create a new server fixture and wait for it to be ready. This + /// is called "create" rather than new because it is async and + /// waits. The shared database can be used immediately. + /// + /// This is currently implemented as a singleton so all tests *must* + /// use a new database and not interfere with the existing database. 
+ pub async fn create_shared() -> Self { + // Try and reuse the same shared server, if there is already + // one present + static SHARED_SERVER: OnceCell>> = OnceCell::new(); + + let shared_server = SHARED_SERVER.get_or_init(|| parking_lot::Mutex::new(Weak::new())); + + let shared_upgraded = { + let locked = shared_server.lock(); + locked.upgrade() + }; + + // is a shared server already present? + let server = match shared_upgraded { + Some(server) => server, + None => { + // if not, create one + let mut server = TestServer::new(); + // ensure the server is ready + server.wait_until_ready(InitialConfig::Onboarded).await; + + let server = Arc::new(server); + // save a reference for other threads that may want to + // use this server, but don't prevent it from being + // destroyed when going out of scope + let mut shared_server = shared_server.lock(); + *shared_server = Arc::downgrade(&server); + server + } + }; + + Self { server } + } + + /// Create a new server fixture and wait for it to be ready. This + /// is called "create" rather than new because it is async and + /// waits. The database is left unconfigured and is not shared + /// with any other tests. 
+ pub async fn create_single_use() -> Self { + let mut server = TestServer::new(); + + // ensure the server is ready + server.wait_until_ready(InitialConfig::None).await; + + let server = Arc::new(server); + + Self { server } + } + + /// Return a client suitable for communicating with this server + pub fn client(&self) -> influxdb2_client::Client { + match self.server.admin_token.as_ref() { + Some(token) => influxdb2_client::Client::new(self.http_base(), token), + None => influxdb2_client::Client::new(self.http_base(), ""), + } + } + + /// Return the http base URL for the HTTP API + pub fn http_base(&self) -> &str { + &self.server.http_base + } +} + +/// Specifies whether the server should be set up initially +#[derive(Debug, Copy, Clone, PartialEq)] +enum InitialConfig { + /// Don't set up the server, the test will (for testing onboarding) + None, + /// Onboard the server and set up the client with the associated token (for + /// most tests) + Onboarded, +} + +// These port numbers are chosen to not collide with a development ioxd/influxd +// server running locally. +// TODO(786): allocate random free ports instead of hardcoding. +// TODO(785): we cannot use localhost here. +static NEXT_PORT: AtomicUsize = AtomicUsize::new(8190); + +/// Represents the current known state of a TestServer +#[derive(Debug)] +enum ServerState { + Started, + Ready, + Error, +} + +const ADMIN_TEST_USER: &str = "admin-test-user"; +const ADMIN_TEST_ORG: &str = "admin-test-org"; +const ADMIN_TEST_BUCKET: &str = "admin-test-bucket"; +const ADMIN_TEST_PASSWORD: &str = "admin-test-password"; + +#[derive(Debug)] +struct TestServer { + /// Is the server ready to accept connections? 
+ ready: Mutex, + /// Handle to the server process being controlled + server_process: Child, + /// When using Docker, the name of the detached child + docker_name: Option, + /// HTTP API base + http_base: String, + /// Admin token, if onboarding has happened + admin_token: Option, +} + +impl TestServer { + fn new() -> Self { + let ready = Mutex::new(ServerState::Started); + let http_port = NEXT_PORT.fetch_add(1, SeqCst); + let http_base = format!("http://127.0.0.1:{http_port}"); + + let temp_dir = test_helpers::tmp_dir().unwrap(); + + let mut log_path = temp_dir.path().to_path_buf(); + log_path.push(format!("influxdb_server_fixture_{http_port}.log")); + + let mut bolt_path = temp_dir.path().to_path_buf(); + bolt_path.push(format!("influxd_{http_port}.bolt")); + + let mut engine_path = temp_dir.path().to_path_buf(); + engine_path.push(format!("influxd_{http_port}_engine")); + + println!("****************"); + println!("Server Logging to {log_path:?}"); + println!("****************"); + let log_file = File::create(log_path).expect("Opening log file"); + + let stdout_log_file = log_file + .try_clone() + .expect("cloning file handle for stdout"); + let stderr_log_file = log_file; + + let local = std::env::var("INFLUXDB_IOX_INTEGRATION_LOCAL").is_ok(); + + let (server_process, docker_name) = if local { + let cmd = Command::new("influxd") + .arg("--http-bind-address") + .arg(format!(":{http_port}")) + .arg("--bolt-path") + .arg(bolt_path) + .arg("--engine-path") + .arg(engine_path) + // redirect output to log file + .stdout(stdout_log_file) + .stderr(stderr_log_file) + .spawn() + .expect("starting of local server process"); + (cmd, None) + } else { + let ci_image = "quay.io/influxdb/rust:ci"; + let container_name = format!("influxdb2_{http_port}"); + + Command::new("docker") + .arg("container") + .arg("run") + .arg("--name") + .arg(&container_name) + .arg("--publish") + .arg(format!("{http_port}:8086")) + .arg("--rm") + .arg("--pull") + .arg("always") + .arg("--detach") 
+ .arg(ci_image) + .arg("influxd") + .output() + .expect("starting of docker server process"); + + let cmd = Command::new("docker") + .arg("logs") + .arg(&container_name) + // redirect output to log file + .stdout(stdout_log_file) + .stderr(stderr_log_file) + .spawn() + .expect("starting of docker logs process"); + + (cmd, Some(container_name)) + }; + + Self { + ready, + server_process, + docker_name, + http_base, + admin_token: None, + } + } + + async fn wait_until_ready(&mut self, initial_config: InitialConfig) { + let mut ready = self.ready.lock().await; + match *ready { + ServerState::Started => {} // first time, need to try and start it + ServerState::Ready => { + return; + } + ServerState::Error => { + panic!("Server was previously found to be in Error, aborting"); + } + } + + let try_http_connect = async { + let client = reqwest::Client::new(); + let url = format!("{}/health", self.http_base); + let mut interval = tokio::time::interval(Duration::from_secs(5)); + loop { + match client.get(&url).send().await { + Ok(resp) => { + println!("Successfully got a response from HTTP: {resp:?}"); + return; + } + Err(e) => { + println!("Waiting for HTTP server to be up: {e}"); + } + } + interval.tick().await; + } + }; + + let capped_check = tokio::time::timeout(Duration::from_secs(100), try_http_connect); + + match capped_check.await { + Ok(_) => { + println!("Successfully started {self}"); + *ready = ServerState::Ready; + } + Err(e) => { + // tell others that this server had some problem + *ready = ServerState::Error; + std::mem::drop(ready); + panic!("Server was not ready in required time: {e}"); + } + } + + // Onboard, if requested. 
+ if initial_config == InitialConfig::Onboarded { + let client = influxdb2_client::Client::new(&self.http_base, ""); + let response = client + .onboarding( + ADMIN_TEST_USER, + ADMIN_TEST_ORG, + ADMIN_TEST_BUCKET, + Some(ADMIN_TEST_PASSWORD.to_string()), + Some(0), + None, + ) + .await; + + match response { + Ok(onboarding) => { + let token = onboarding + .auth + .expect("Onboarding should have returned auth info") + .token + .expect("Onboarding auth should have returned a token"); + self.admin_token = Some(token); + } + Err(e) => { + *ready = ServerState::Error; + std::mem::drop(ready); + panic!("Could not onboard: {e}"); + } + } + } + } +} + +impl std::fmt::Display for TestServer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + write!(f, "TestServer (http api: {})", self.http_base) + } +} + +impl Drop for TestServer { + fn drop(&mut self) { + self.server_process + .kill() + .expect("Should have been able to kill the test server"); + + if let Some(docker_name) = &self.docker_name { + Command::new("docker") + .arg("rm") + .arg("--force") + .arg(docker_name) + .stdout(Stdio::null()) + .status() + .expect("killing of docker process"); + } + } +} diff --git a/influxdb2_client/tests/health.rs b/influxdb2_client/tests/health.rs new file mode 100644 index 0000000..293550e --- /dev/null +++ b/influxdb2_client/tests/health.rs @@ -0,0 +1,17 @@ +pub mod common; +use common::server_fixture::ServerFixture; + +type Result> = std::result::Result; + +#[tokio::test] +async fn get_health() -> Result { + // Using a server that has been set up + let server_fixture = maybe_skip_integration!(ServerFixture::create_shared()).await; + let client = server_fixture.client(); + + let res = client.health().await?; + + assert_eq!(res.status, influxdb2_client::models::Status::Pass); + + Ok(()) +} diff --git a/influxdb2_client/tests/setup.rs b/influxdb2_client/tests/setup.rs new file mode 100644 index 0000000..1ba03f9 --- /dev/null +++ 
b/influxdb2_client/tests/setup.rs @@ -0,0 +1,118 @@ +pub mod common; +use common::server_fixture::ServerFixture; + +type Result> = std::result::Result; + +#[tokio::test] +async fn new_server_needs_onboarded() -> Result { + let server_fixture = maybe_skip_integration!(ServerFixture::create_single_use()).await; + let client = server_fixture.client(); + + let res = client.is_onboarding_allowed().await?; + assert!(res); + + // Creating a new setup user without first onboarding is an error + let username = "some-user"; + let org = "some-org"; + let bucket = "some-bucket"; + let password = "some-password"; + let retention_period_hrs = 0; + + let err = client + .post_setup_user( + username, + org, + bucket, + Some(password.to_string()), + Some(retention_period_hrs), + None, + ) + .await + .expect_err("Expected error, got success"); + + assert!(matches!( + err, + influxdb2_client::RequestError::Http { + status: reqwest::StatusCode::UNAUTHORIZED, + .. + } + )); + + Ok(()) +} + +#[tokio::test] +async fn onboarding() -> Result { + let server_fixture = maybe_skip_integration!(ServerFixture::create_single_use()).await; + let client = server_fixture.client(); + + let username = "some-user"; + let org = "some-org"; + let bucket = "some-bucket"; + let password = "some-password"; + let retention_period_hrs = 0; + + client + .onboarding( + username, + org, + bucket, + Some(password.to_string()), + Some(retention_period_hrs), + None, + ) + .await?; + + let res = client.is_onboarding_allowed().await?; + assert!(!res); + + // Onboarding twice is an error + let err = client + .onboarding( + username, + org, + bucket, + Some(password.to_string()), + Some(retention_period_hrs), + None, + ) + .await + .expect_err("Expected error, got success"); + + assert!(matches!( + err, + influxdb2_client::RequestError::Http { + status: reqwest::StatusCode::UNPROCESSABLE_ENTITY, + .. 
+ } + )); + + Ok(()) +} + +#[tokio::test] +async fn create_users() -> Result { + // Using a server that has been set up + let server_fixture = maybe_skip_integration!(ServerFixture::create_shared()).await; + let client = server_fixture.client(); + + let username = "another-user"; + let org = "another-org"; + let bucket = "another-bucket"; + let password = "another-password"; + let retention_period_hrs = 0; + + // Creating a user should work + client + .post_setup_user( + username, + org, + bucket, + Some(password.to_string()), + Some(retention_period_hrs), + None, + ) + .await?; + + Ok(()) +} diff --git a/influxdb_influxql_parser/Cargo.toml b/influxdb_influxql_parser/Cargo.toml new file mode 100644 index 0000000..6751e45 --- /dev/null +++ b/influxdb_influxql_parser/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "influxdb_influxql_parser" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[lints] +workspace = true + +[dependencies] # In alphabetical order +nom = { version = "7", default-features = false, features = ["std"] } +once_cell = "1" +chrono = { version = "0.4", default-features = false, features = ["std"] } +chrono-tz = { version = "0.8" } +num-integer = { version = "0.1", default-features = false, features = ["i128", "std"] } +num-traits = "0.2" +workspace-hack = { version = "0.1", path = "../workspace-hack" } + +[dev-dependencies] # In alphabetical order +test_helpers = { path = "../test_helpers" } +assert_matches = "1" +insta = { version = "1.34.0", features = ["yaml"] } +paste = "1.0.14" diff --git a/influxdb_influxql_parser/src/common.rs b/influxdb_influxql_parser/src/common.rs new file mode 100644 index 0000000..e3788c5 --- /dev/null +++ b/influxdb_influxql_parser/src/common.rs @@ -0,0 +1,1014 @@ +//! Type and parsers common to many statements. 
+ +use crate::expression::conditional::{conditional_expression, ConditionalExpression}; +use crate::identifier::{identifier, Identifier}; +use crate::internal::{expect, verify, ParseResult}; +use crate::keywords::{keyword, Token}; +use crate::literal::unsigned_integer; +use crate::string::{regex, Regex}; +use core::fmt; +use nom::branch::alt; +use nom::bytes::complete::{tag, take_till, take_until}; +use nom::character::complete::{char, multispace1}; +use nom::combinator::{map, opt, recognize, value}; +use nom::multi::{fold_many0, fold_many1, separated_list1}; +use nom::sequence::{delimited, pair, preceded, terminated}; +use std::fmt::{Display, Formatter}; +use std::mem; +use std::ops::{Deref, DerefMut}; + +/// A error returned when parsing an InfluxQL query, expressions. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ParseError { + pub(crate) message: String, + pub(crate) pos: usize, +} + +impl Display for ParseError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{} at pos {}", self.message, self.pos) + } +} + +/// Represents a measurement name as either an identifier or a regular expression. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum MeasurementName { + /// A measurement name expressed as an [`Identifier`]. + Name(Identifier), + + /// A measurement name expressed as a [`Regex`]. + Regex(Regex), +} + +impl Parser for MeasurementName { + /// Parse a measurement name, which may be an identifier or a regular expression. + fn parse(i: &str) -> ParseResult<&str, Self> { + alt(( + map(identifier, MeasurementName::Name), + map(regex, MeasurementName::Regex), + ))(i) + } +} + +impl Display for MeasurementName { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Name(ident) => fmt::Display::fmt(ident, f), + Self::Regex(regex) => fmt::Display::fmt(regex, f), + } + } +} + +/// Represents a fully-qualified, 3-part measurement name. 
+#[derive(Clone, Debug, PartialEq, Eq)] +pub struct QualifiedMeasurementName { + /// An optional database name. + pub database: Option, + + /// An optional retention policy. + pub retention_policy: Option, + + /// The measurement name. + pub name: MeasurementName, +} + +impl Display for QualifiedMeasurementName { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self { + database: None, + retention_policy: None, + name, + } => write!(f, "{name}"), + Self { + database: Some(db), + retention_policy: None, + name, + } => write!(f, "{db}..{name}"), + Self { + database: None, + retention_policy: Some(rp), + name, + } => write!(f, "{rp}.{name}"), + Self { + database: Some(db), + retention_policy: Some(rp), + name, + } => write!(f, "{db}.{rp}.{name}"), + } + } +} + +/// Match a fully-qualified, 3-part measurement name. +/// +/// ```text +/// qualified_measurement_name ::= measurement_name | +/// ( policy_name "." measurement_name ) | +/// ( db_name "." policy_name? "." measurement_name ) +/// +/// db_name ::= identifier +/// policy_name ::= identifier +/// measurement_name ::= identifier | regex_lit +/// ``` +pub(crate) fn qualified_measurement_name(i: &str) -> ParseResult<&str, QualifiedMeasurementName> { + let (remaining_input, (opt_db_rp, name)) = pair( + opt(alt(( + // database "." retention_policy "." + map( + pair( + terminated(identifier, tag(".")), + terminated(identifier, tag(".")), + ), + |(db, rp)| (Some(db), Some(rp)), + ), + // database ".." + map(terminated(identifier, tag("..")), |db| (Some(db), None)), + // retention_policy "." 
+ map(terminated(identifier, tag(".")), |rp| (None, Some(rp))), + ))), + MeasurementName::parse, + )(i)?; + + // Extract possible `database` and / or `retention_policy` + let (database, retention_policy) = match opt_db_rp { + Some(db_rp) => db_rp, + _ => (None, None), + }; + + Ok(( + remaining_input, + QualifiedMeasurementName { + database, + retention_policy, + name, + }, + )) +} + +/// Parse a SQL-style single-line comment +fn comment_single_line(i: &str) -> ParseResult<&str, &str> { + recognize(pair(tag("--"), take_till(|c| c == '\n' || c == '\r')))(i) +} + +/// Parse a SQL-style inline comment, which can span multiple lines +fn comment_inline(i: &str) -> ParseResult<&str, &str> { + recognize(delimited( + tag("/*"), + expect( + "invalid inline comment, missing closing */", + take_until("*/"), + ), + tag("*/"), + ))(i) +} + +/// Repeats the embedded parser until it fails, discarding the results. +/// +/// This parser is used as a non-allocating version of [`nom::multi::many0`]. +fn many0_<'a, A, F>(mut f: F) -> impl FnMut(&'a str) -> ParseResult<&'a str, ()> +where + F: FnMut(&'a str) -> ParseResult<&'a str, A>, +{ + move |i| fold_many0(&mut f, || (), |_, _| ())(i) +} + +/// Optionally consume all whitespace, single-line or inline comments +pub(crate) fn ws0(i: &str) -> ParseResult<&str, ()> { + many0_(alt((multispace1, comment_single_line, comment_inline)))(i) +} + +/// Runs the embedded parser until it fails, discarding the results. +/// Fails if the embedded parser does not produce at least one result. +/// +/// This parser is used as a non-allocating version of [`nom::multi::many1`]. 
+fn many1_<'a, A, F>(mut f: F) -> impl FnMut(&'a str) -> ParseResult<&'a str, ()> +where + F: FnMut(&'a str) -> ParseResult<&'a str, A>, +{ + move |i| fold_many1(&mut f, || (), |_, _| ())(i) +} + +/// Must consume either whitespace, single-line or inline comments +pub(crate) fn ws1(i: &str) -> ParseResult<&str, ()> { + many1_(alt((multispace1, comment_single_line, comment_inline)))(i) +} + +/// Implements common behaviour for u64 tuple-struct types +#[macro_export] +macro_rules! impl_tuple_clause { + ($NAME:ident, $FOR:ty) => { + impl $NAME { + /// Create a new instance with the specified value. + pub fn new(value: $FOR) -> Self { + Self(value) + } + } + + impl std::ops::DerefMut for $NAME { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + + impl std::ops::Deref for $NAME { + type Target = $FOR; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl From<$FOR> for $NAME { + fn from(value: $FOR) -> Self { + Self(value) + } + } + }; +} + +/// Represents the value for a `LIMIT` clause. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct LimitClause(pub(crate) u64); + +impl_tuple_clause!(LimitClause, u64); + +impl Display for LimitClause { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "LIMIT {}", self.0) + } +} + +/// Parse a `LIMIT ` clause. +pub(crate) fn limit_clause(i: &str) -> ParseResult<&str, LimitClause> { + preceded( + pair(keyword("LIMIT"), ws1), + expect( + "invalid LIMIT clause, expected unsigned integer", + map(unsigned_integer, LimitClause), + ), + )(i) +} + +/// Represents the value for a `OFFSET` clause. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct OffsetClause(pub(crate) u64); + +impl_tuple_clause!(OffsetClause, u64); + +impl Display for OffsetClause { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "OFFSET {}", self.0) + } +} + +/// Parse an `OFFSET ` clause. 
+pub(crate) fn offset_clause(i: &str) -> ParseResult<&str, OffsetClause> { + preceded( + pair(keyword("OFFSET"), ws1), + expect( + "invalid OFFSET clause, expected unsigned integer", + map(unsigned_integer, OffsetClause), + ), + )(i) +} + +/// Parse a terminator that ends a SQL statement. +pub(crate) fn statement_terminator(i: &str) -> ParseResult<&str, ()> { + value((), char(';'))(i) +} + +/// Represents the `WHERE` clause of a statement. +#[derive(Debug, Clone, PartialEq)] +pub struct WhereClause(pub(crate) ConditionalExpression); + +impl WhereClause { + /// Create an instance of a `WhereClause` using `expr` + pub fn new(expr: ConditionalExpression) -> Self { + Self(expr) + } +} + +impl DerefMut for WhereClause { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Deref for WhereClause { + type Target = ConditionalExpression; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Display for WhereClause { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "WHERE {}", self.0) + } +} + +/// Parse a `WHERE` clause. +pub(crate) fn where_clause(i: &str) -> ParseResult<&str, WhereClause> { + preceded( + pair(keyword("WHERE"), ws0), + map(conditional_expression, WhereClause), + )(i) +} + +/// Represents an InfluxQL `ORDER BY` clause. +#[derive(Default, Debug, Clone, Copy, Eq, PartialEq)] +pub enum OrderByClause { + /// Signals the `ORDER BY` is in ascending order. + #[default] + Ascending, + + /// Signals the `ORDER BY` is in descending order. + Descending, +} + +impl OrderByClause { + /// Return `true` if the order by clause is ascending. + pub fn is_ascending(self) -> bool { + matches!(self, Self::Ascending) + } +} + +impl Display for OrderByClause { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ORDER BY TIME {}", + match self { + Self::Ascending => "ASC", + Self::Descending => "DESC", + } + ) + } +} + +/// Parse an InfluxQL `ORDER BY` clause. 
+/// +/// An `ORDER BY` in InfluxQL is limited when compared to the equivalent +/// SQL definition. It is defined by the following [EBNF] notation: +/// +/// ```text +/// order_by ::= "ORDER" "BY" (time_order | order) +/// order ::= "ASC | "DESC +/// time_order ::= "TIME" order? +/// ``` +/// +/// Resulting in the following valid strings: +/// +/// ```text +/// ORDER BY ASC +/// ORDER BY DESC +/// ORDER BY time +/// ORDER BY time ASC +/// ORDER BY time DESC +/// ``` +/// +/// [EBNF]: https://www.w3.org/TR/2010/REC-xquery-20101214/#EBNFNotation +pub(crate) fn order_by_clause(i: &str) -> ParseResult<&str, OrderByClause> { + let order = || { + preceded( + ws1, + alt(( + value(OrderByClause::Ascending, keyword("ASC")), + value(OrderByClause::Descending, keyword("DESC")), + )), + ) + }; + + preceded( + // "ORDER" "BY" + pair(keyword("ORDER"), preceded(ws1, keyword("BY"))), + expect( + "invalid ORDER BY, expected ASC, DESC or TIME", + alt(( + // "ASC" | "DESC" + order(), + // "TIME" ( "ASC" | "DESC" )? + map( + preceded( + preceded( + ws1, + verify("invalid ORDER BY, expected TIME column", identifier, |v| { + Token(&v.0) == Token("time") + }), + ), + opt(order()), + ), + Option::<_>::unwrap_or_default, + ), + )), + ), + )(i) +} + +/// Parser is a trait that allows a type to parse itself. +pub trait Parser: Sized { + /// Parse this type from the string `i`. + fn parse(i: &str) -> ParseResult<&str, Self>; +} + +/// `OneOrMore` is a container for representing a minimum of one `T`. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct OneOrMore { + pub(crate) contents: Vec, +} + +#[allow(clippy::len_without_is_empty)] +impl OneOrMore { + /// Construct a new `OneOrMore` with `contents`. + /// + /// **NOTE:** that `new` panics if contents is empty. + pub fn new(contents: Vec) -> Self { + if contents.is_empty() { + panic!("OneOrMore requires elements"); + } + + Self { contents } + } + + /// Returns the first element. 
+ pub fn head(&self) -> &T { + self.contents.first().unwrap() + } + + /// Returns the remaining elements after [Self::head]. + pub fn tail(&self) -> &[T] { + &self.contents[1..] + } + + /// Returns the total number of elements. + /// Note that `len` ≥ 1. + pub fn len(&self) -> usize { + self.contents.len() + } +} + +impl Deref for OneOrMore { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + &self.contents + } +} + +impl DerefMut for OneOrMore { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.contents + } +} + +impl OneOrMore { + /// Parse a list of one or more `T`, separated by commas. + /// + /// Returns an error using `msg` if `separated_list1` fails to parse any elements. + pub(crate) fn separated_list1<'a>( + msg: &'static str, + ) -> impl FnMut(&'a str) -> ParseResult<&'a str, Self> { + move |i: &str| { + map( + expect( + msg, + separated_list1(preceded(ws0, char(',')), preceded(ws0, T::parse)), + ), + Self::new, + )(i) + } + } +} + +/// `ZeroOrMore` is a container for representing zero or more elements of type `T`. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ZeroOrMore { + pub(crate) contents: Vec, +} + +impl ZeroOrMore { + /// Construct a new `ZeroOrMore` with `contents`. + pub fn new(contents: Vec) -> Self { + Self { contents } + } + + /// Returns the first element or `None` if the container is empty. + pub fn head(&self) -> Option<&T> { + self.contents.first() + } + + /// Returns the remaining elements after [Self::head]. + pub fn tail(&self) -> &[T] { + if self.contents.len() < 2 { + &[] + } else { + &self.contents[1..] + } + } + + /// Returns the total number of elements in the container. + pub fn len(&self) -> usize { + self.contents.len() + } + + /// Returns true if the container has no elements. + pub fn is_empty(&self) -> bool { + self.contents.is_empty() + } + + /// Takes the vector out of the receiver, leaving a default vector value in its place. 
+ pub fn take(&mut self) -> Vec { + mem::take(&mut self.contents) + } + + /// Replaces the actual value in the receiver by the value given in parameter, + /// returning the old value if present. + pub fn replace(&mut self, value: Vec) -> Vec { + mem::replace(&mut self.contents, value) + } +} + +impl Deref for ZeroOrMore { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + &self.contents + } +} + +impl DerefMut for ZeroOrMore { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.contents + } +} + +impl ZeroOrMore { + /// Parse a list of one or more `T`, separated by commas. + /// + /// Returns an error using `msg` if `separated_list1` fails to parse any elements. + pub(crate) fn separated_list1<'a>( + msg: &'static str, + ) -> impl FnMut(&'a str) -> ParseResult<&'a str, Self> { + move |i: &str| { + map( + expect( + msg, + separated_list1(preceded(ws0, char(',')), preceded(ws0, T::parse)), + ), + Self::new, + )(i) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{assert_error, assert_expect_error}; + use assert_matches::assert_matches; + use nom::character::complete::alphanumeric1; + + impl From<&str> for MeasurementName { + /// Convert a `str` to [`MeasurementName::Name`]. + fn from(s: &str) -> Self { + Self::Name(Identifier::new(s.into())) + } + } + + impl QualifiedMeasurementName { + /// Constructs a new `MeasurementNameExpression` with the specified `name`. + pub fn new(name: MeasurementName) -> Self { + Self { + database: None, + retention_policy: None, + name, + } + } + + /// Constructs a new `MeasurementNameExpression` with the specified `name` and `database`. + pub fn new_db(name: MeasurementName, database: Identifier) -> Self { + Self { + database: Some(database), + retention_policy: None, + name, + } + } + + /// Constructs a new `MeasurementNameExpression` with the specified `name`, `database` and `retention_policy`. 
+ pub fn new_db_rp( + name: MeasurementName, + database: Identifier, + retention_policy: Identifier, + ) -> Self { + Self { + database: Some(database), + retention_policy: Some(retention_policy), + name, + } + } + } + + #[test] + fn test_qualified_measurement_name() { + use MeasurementName::*; + + let (_, got) = qualified_measurement_name("diskio").unwrap(); + assert_eq!( + got, + QualifiedMeasurementName { + database: None, + retention_policy: None, + name: Name("diskio".into()), + } + ); + + let (_, got) = qualified_measurement_name("/diskio/").unwrap(); + assert_eq!( + got, + QualifiedMeasurementName { + database: None, + retention_policy: None, + name: Regex("diskio".into()), + } + ); + + let (_, got) = qualified_measurement_name("telegraf.autogen.diskio").unwrap(); + assert_eq!( + got, + QualifiedMeasurementName { + database: Some("telegraf".into()), + retention_policy: Some("autogen".into()), + name: Name("diskio".into()), + } + ); + + let (_, got) = qualified_measurement_name("telegraf.autogen./diskio/").unwrap(); + assert_eq!( + got, + QualifiedMeasurementName { + database: Some("telegraf".into()), + retention_policy: Some("autogen".into()), + name: Regex("diskio".into()), + } + ); + + let (_, got) = qualified_measurement_name("telegraf..diskio").unwrap(); + assert_eq!( + got, + QualifiedMeasurementName { + database: Some("telegraf".into()), + retention_policy: None, + name: Name("diskio".into()), + } + ); + + let (_, got) = qualified_measurement_name("telegraf../diskio/").unwrap(); + assert_eq!( + got, + QualifiedMeasurementName { + database: Some("telegraf".into()), + retention_policy: None, + name: Regex("diskio".into()), + } + ); + + // With whitespace + let (_, got) = qualified_measurement_name("\"telegraf\".. \"diskio\"").unwrap(); + assert_eq!( + got, + QualifiedMeasurementName { + database: Some("telegraf".into()), + retention_policy: None, + name: Name("diskio".into()), + } + ); + + let (_, got) = + qualified_measurement_name("telegraf. 
/* a comment */ autogen. diskio").unwrap(); + assert_eq!( + got, + QualifiedMeasurementName { + database: Some("telegraf".into()), + retention_policy: Some("autogen".into()), + name: Name("diskio".into()), + } + ); + + // Whitespace following identifier is not supported + let (rem, got) = qualified_measurement_name("telegraf . autogen. diskio").unwrap(); + assert_eq!(rem, " . autogen. diskio"); + assert_eq!( + got, + QualifiedMeasurementName { + database: None, + retention_policy: None, + name: Name("telegraf".into()), + } + ); + + // Fallible + + // Whitespace preceding regex is not supported + qualified_measurement_name("telegraf.autogen. /diskio/").unwrap_err(); + } + + #[test] + fn test_limit_clause() { + let (_, got) = limit_clause("LIMIT 587").unwrap(); + assert_eq!(*got, 587); + + // case insensitive + let (_, got) = limit_clause("limit 587").unwrap(); + assert_eq!(*got, 587); + + // extra spaces between tokens + let (_, got) = limit_clause("LIMIT 123").unwrap(); + assert_eq!(*got, 123); + + // not digits + assert_expect_error!( + limit_clause("LIMIT from"), + "invalid LIMIT clause, expected unsigned integer" + ); + + // incomplete input + assert_expect_error!( + limit_clause("LIMIT "), + "invalid LIMIT clause, expected unsigned integer" + ); + + // overflow + assert_expect_error!( + limit_clause("LIMIT 34593745733489743985734857394"), + "unable to parse unsigned integer" + ); + } + + #[test] + fn test_offset_clause() { + let (_, got) = offset_clause("OFFSET 587").unwrap(); + assert_eq!(*got, 587); + + // case insensitive + let (_, got) = offset_clause("offset 587").unwrap(); + assert_eq!(*got, 587); + + // extra spaces between tokens + let (_, got) = offset_clause("OFFSET 123").unwrap(); + assert_eq!(*got, 123); + + // not digits + assert_expect_error!( + offset_clause("OFFSET from"), + "invalid OFFSET clause, expected unsigned integer" + ); + + // incomplete input + assert_expect_error!( + offset_clause("OFFSET "), + "invalid OFFSET clause, expected 
unsigned integer" + ); + + // overflow + assert_expect_error!( + offset_clause("OFFSET 34593745733489743985734857394"), + "unable to parse unsigned integer" + ); + } + + #[test] + fn test_order_by() { + use OrderByClause::*; + + let (_, got) = order_by_clause("ORDER by asc").unwrap(); + assert_eq!(got, Ascending); + + let (_, got) = order_by_clause("ORDER by desc").unwrap(); + assert_eq!(got, Descending); + + // "time" as a quoted identifier + let (_, got) = order_by_clause("ORDER by \"time\" asc").unwrap(); + assert_eq!(got, Ascending); + + let (_, got) = order_by_clause("ORDER by time asc").unwrap(); + assert_eq!(got, Ascending); + + let (_, got) = order_by_clause("ORDER by time desc").unwrap(); + assert_eq!(got, Descending); + + // default case is ascending + let (_, got) = order_by_clause("ORDER by time").unwrap(); + assert_eq!(got, Ascending); + + // case insensitive + let (_, got) = order_by_clause("ORDER by \"TIME\"").unwrap(); + assert_eq!(got, Ascending); + + let (_, got) = order_by_clause("ORDER by Time").unwrap(); + assert_eq!(got, Ascending); + + // does not consume remaining input + let (i, got) = order_by_clause("ORDER by time LIMIT 10").unwrap(); + assert_eq!(got, Ascending); + assert_eq!(i, " LIMIT 10"); + + // Fallible cases + + // Must be "time" identifier + assert_expect_error!( + order_by_clause("ORDER by foo"), + "invalid ORDER BY, expected TIME column" + ); + } + + #[test] + fn test_where_clause() { + // Can parse a WHERE clause + where_clause("WHERE foo = 'bar'").unwrap(); + + // Remaining input is not consumed + let (i, _) = where_clause("WHERE foo = 'bar' LIMIT 10").unwrap(); + assert_eq!(i, " LIMIT 10"); + + // Without unnecessary whitespace + where_clause("WHERE(foo = 'bar')").unwrap(); + + let (rem, _) = where_clause("WHERE/* a comment*/foo = 'bar'").unwrap(); + assert_eq!(rem, ""); + + // Fallible cases + where_clause("WHERE foo = LIMIT 10").unwrap_err(); + where_clause("WHERE").unwrap_err(); + } + + #[test] + fn 
test_statement_terminator() { + let (i, _) = statement_terminator(";foo").unwrap(); + assert_eq!(i, "foo"); + + let (i, _) = statement_terminator("; foo").unwrap(); + assert_eq!(i, " foo"); + + // Fallible cases + statement_terminator("foo").unwrap_err(); + } + + impl Parser for String { + fn parse(i: &str) -> ParseResult<&str, Self> { + map(alphanumeric1, &str::to_string)(i) + } + } + + type OneOrMoreString = OneOrMore; + + impl Display for OneOrMoreString { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(self.head(), f)?; + for arg in self.tail() { + write!(f, ", {arg}")?; + } + Ok(()) + } + } + + #[test] + #[should_panic(expected = "OneOrMore requires elements")] + fn test_one_or_more() { + let (_, got) = OneOrMoreString::separated_list1("Expects one or more")("foo").unwrap(); + assert_eq!(got.len(), 1); + assert_eq!(got.head(), "foo"); + assert_eq!(*got, vec!["foo"]); // deref + assert_eq!(got.to_string(), "foo"); + + let (_, got) = + OneOrMoreString::separated_list1("Expects one or more")("foo , bar,foobar").unwrap(); + assert_eq!(got.len(), 3); + assert_eq!(got.head(), "foo"); + assert_eq!(got.tail(), vec!["bar", "foobar"]); + assert_eq!(*got, vec!["foo", "bar", "foobar"]); // deref + assert_eq!(got.to_string(), "foo, bar, foobar"); + + // Fallible cases + + assert_expect_error!( + OneOrMoreString::separated_list1("Expects one or more")("+"), + "Expects one or more" + ); + + // should panic + OneOrMoreString::new(vec![]); + } + + type ZeroOrMoreString = ZeroOrMore; + + impl Display for ZeroOrMoreString { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if let Some(first) = self.head() { + Display::fmt(first, f)?; + for arg in self.tail() { + write!(f, ", {arg}")?; + } + } + + Ok(()) + } + } + + #[test] + fn test_zero_or_more() { + let (_, got) = ZeroOrMoreString::separated_list1("Expects one or more")("foo").unwrap(); + assert_eq!(got.len(), 1); + assert_eq!(got.head().unwrap(), "foo"); + assert_eq!(*got, 
vec!["foo"]); // deref + assert_eq!(got.to_string(), "foo"); + + let (_, got) = + ZeroOrMoreString::separated_list1("Expects one or more")("foo , bar,foobar").unwrap(); + assert_eq!(got.len(), 3); + assert_eq!(got.head().unwrap(), "foo"); + assert_eq!(got.tail(), vec!["bar", "foobar"]); + assert_eq!(*got, vec!["foo", "bar", "foobar"]); // deref + assert_eq!(got.to_string(), "foo, bar, foobar"); + + // should not panic + let got = ZeroOrMoreString::new(vec![]); + assert!(got.is_empty()); + assert_matches!(got.head(), None); + assert_eq!(got.tail().len(), 0); + + // Fallible cases + + assert_expect_error!( + OneOrMoreString::separated_list1("Expects one or more")("+"), + "Expects one or more" + ); + } + + #[test] + fn test_comment_single_line() { + // Comment to EOF + let (rem, _) = comment_single_line("-- this is a test").unwrap(); + assert_eq!(rem, ""); + + // Comment to EOL + let (rem, _) = comment_single_line("-- this is a test\nmore text").unwrap(); + assert_eq!(rem, "\nmore text"); + + // Empty comments + let (rem, _) = comment_single_line("--").unwrap(); + assert_eq!(rem, ""); + let (rem, _) = comment_single_line("--\nSELECT").unwrap(); + assert_eq!(rem, "\nSELECT"); + } + + #[test] + fn test_comment_inline() { + let (rem, _) = comment_inline("/* this is a test */").unwrap(); + assert_eq!(rem, ""); + + let (rem, _) = comment_inline("/* this is a test*/more text").unwrap(); + assert_eq!(rem, "more text"); + + let (rem, _) = comment_inline("/* this\nis a test*/more text").unwrap(); + assert_eq!(rem, "more text"); + + // Ignores embedded /* + let (rem, _) = comment_inline("/* this /* is a test*/more text").unwrap(); + assert_eq!(rem, "more text"); + + // Fallible cases + + assert_expect_error!( + comment_inline("/* this is a test"), + "invalid inline comment, missing closing */" + ); + } + + #[test] + fn test_ws0() { + let (rem, _) = ws0(" -- this is a comment\n/* and some more*/ \t").unwrap(); + assert_eq!(rem, ""); + + let (rem, _) = ws0(" -- this is a 
comment\n/* and some more*/ \tSELECT").unwrap(); + assert_eq!(rem, "SELECT"); + + // no whitespace + let (rem, _) = ws0("SELECT").unwrap(); + assert_eq!(rem, "SELECT"); + } + + #[test] + fn test_ws1() { + let (rem, _) = ws1(" -- this is a comment\n/* and some more*/ \t").unwrap(); + assert_eq!(rem, ""); + + let (rem, _) = ws1(" -- this is a comment\n/* and some more*/ \tSELECT").unwrap(); + assert_eq!(rem, "SELECT"); + + // Fallible cases + + // Missing whitespace + assert_error!(ws1("SELECT"), Many1); + } +} diff --git a/influxdb_influxql_parser/src/create.rs b/influxdb_influxql_parser/src/create.rs new file mode 100644 index 0000000..6b6ae9c --- /dev/null +++ b/influxdb_influxql_parser/src/create.rs @@ -0,0 +1,236 @@ +//! Types and parsers for the [`CREATE DATABASE`][sql] schema statement. +//! +//! [sql]: https://docs.influxdata.com/influxdb/v1.8/query_language/manage-database/#create-database + +use crate::common::ws1; +use crate::identifier::{identifier, Identifier}; +use crate::internal::{expect, ParseResult}; +use crate::keywords::keyword; +use crate::literal::{duration, unsigned_integer, Duration}; +use crate::statement::Statement; +use nom::branch::alt; +use nom::combinator::{map, opt, peek}; +use nom::sequence::{pair, preceded, tuple}; +use std::fmt::{Display, Formatter}; + +pub(crate) fn create_statement(i: &str) -> ParseResult<&str, Statement> { + preceded( + pair(keyword("CREATE"), ws1), + expect( + "Invalid CREATE statement, expected DATABASE following CREATE", + map(create_database, |s| Statement::CreateDatabase(Box::new(s))), + ), + )(i) +} + +/// Represents a `CREATE DATABASE` statement. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CreateDatabaseStatement { + /// Name of database to be created. + pub name: Identifier, + + /// Duration of retention policy. + pub duration: Option, + + /// Replication factor of retention policy. + pub replication: Option, + + /// Shard duration of retention policy. 
+ pub shard_duration: Option, + + /// Retention policy name. + pub retention_name: Option, +} + +impl CreateDatabaseStatement { + /// Returns true if the "WITH" clause is present. + pub fn has_with_clause(&self) -> bool { + self.duration.is_some() + || self.replication.is_some() + || self.shard_duration.is_some() + || self.retention_name.is_some() + } +} + +impl Display for CreateDatabaseStatement { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "CREATE DATABASE {}", self.name)?; + + if self.has_with_clause() { + f.write_str(" WITH")?; + + if let Some(v) = self.duration { + write!(f, " DURATION {v}")?; + } + + if let Some(v) = self.replication { + write!(f, " REPLICATION {v}")?; + } + + if let Some(v) = self.shard_duration { + write!(f, " SHARD DURATION {v}")?; + } + + if let Some(v) = &self.retention_name { + write!(f, " NAME {v}")?; + } + } + Ok(()) + } +} + +fn create_database(i: &str) -> ParseResult<&str, CreateDatabaseStatement> { + let ( + remaining, + ( + _, // "DATABASE" + name, + opt_with_clause, + ), + ) = tuple(( + keyword("DATABASE"), + identifier, + opt(tuple(( + preceded(ws1, keyword("WITH")), + expect( + "invalid WITH clause, expected \"DURATION\", \"REPLICATION\", \"SHARD\" or \"NAME\"", + peek(preceded( + ws1, + alt(( + keyword("DURATION"), + keyword("REPLICATION"), + keyword("SHARD"), + keyword("NAME"), + )), + )), + ), + opt(preceded( + preceded(ws1, keyword("DURATION")), + expect( + "invalid DURATION clause, expected duration", + preceded(ws1, duration), + ), + )), + opt(preceded( + preceded(ws1, keyword("REPLICATION")), + expect( + "invalid REPLICATION clause, expected unsigned integer", + preceded(ws1, unsigned_integer), + ), + )), + opt(preceded( + pair( + preceded(ws1, keyword("SHARD")), + expect( + "invalid SHARD DURATION clause, expected \"DURATION\"", + preceded(ws1, keyword("DURATION")), + ), + ), + expect( + "invalid SHARD DURATION clause, expected duration", + preceded(ws1, duration), + ), + )), + 
opt(preceded( + preceded(ws1, keyword("NAME")), + expect( + "invalid NAME clause, expected identifier", + identifier, + ), + )), + ))), + ))(i)?; + + let (_, _, duration, replication, shard_duration, retention_name) = + opt_with_clause.unwrap_or(("", "", None, None, None, None)); + + Ok(( + remaining, + CreateDatabaseStatement { + name, + duration, + replication, + shard_duration, + retention_name, + }, + )) +} + +#[cfg(test)] +mod test { + use super::create_database; + use super::create_statement; + use crate::assert_expect_error; + + #[test] + fn test_create_statement() { + create_statement("CREATE DATABASE telegraf").unwrap(); + } + + #[test] + fn test_create_database() { + let (rem, got) = create_database("DATABASE telegraf").unwrap(); + assert_eq!(rem, ""); + assert_eq!(got.name, "telegraf".into()); + + let (rem, got) = create_database("DATABASE telegraf WITH DURATION 5m").unwrap(); + assert_eq!(rem, ""); + assert_eq!(got.name, "telegraf".into()); + assert_eq!(got.duration.unwrap().to_string(), "5m"); + + let (rem, got) = create_database("DATABASE telegraf WITH REPLICATION 10").unwrap(); + assert_eq!(rem, ""); + assert_eq!(got.name, "telegraf".into()); + assert_eq!(got.replication.unwrap(), 10); + + let (rem, got) = create_database("DATABASE telegraf WITH SHARD DURATION 6m").unwrap(); + assert_eq!(rem, ""); + assert_eq!(got.name, "telegraf".into()); + assert_eq!(got.shard_duration.unwrap().to_string(), "6m"); + + let (rem, got) = create_database("DATABASE telegraf WITH NAME \"5 minutes\"").unwrap(); + assert_eq!(rem, ""); + assert_eq!(got.name, "telegraf".into()); + assert_eq!(got.retention_name.unwrap(), "5 minutes".into()); + + let (rem, got) = create_database("DATABASE telegraf WITH DURATION 5m REPLICATION 10 SHARD DURATION 6m NAME \"5 minutes\"").unwrap(); + assert_eq!(rem, ""); + assert_eq!(got.name, "telegraf".into()); + assert_eq!(got.duration.unwrap().to_string(), "5m"); + assert_eq!(got.replication.unwrap(), 10); + 
assert_eq!(got.shard_duration.unwrap().to_string(), "6m"); + assert_eq!(got.retention_name.unwrap(), "5 minutes".into()); + + // Fallible + + assert_expect_error!( + create_database("DATABASE telegraf WITH foo"), + "invalid WITH clause, expected \"DURATION\", \"REPLICATION\", \"SHARD\" or \"NAME\"" + ); + + assert_expect_error!( + create_database("DATABASE telegraf WITH DURATION foo"), + "invalid DURATION clause, expected duration" + ); + + assert_expect_error!( + create_database("DATABASE telegraf WITH REPLICATION foo"), + "invalid REPLICATION clause, expected unsigned integer" + ); + + assert_expect_error!( + create_database("DATABASE telegraf WITH SHARD foo"), + "invalid SHARD DURATION clause, expected \"DURATION\"" + ); + + assert_expect_error!( + create_database("DATABASE telegraf WITH SHARD DURATION foo"), + "invalid SHARD DURATION clause, expected duration" + ); + + assert_expect_error!( + create_database("DATABASE telegraf WITH NAME 5"), + "invalid NAME clause, expected identifier" + ); + } +} diff --git a/influxdb_influxql_parser/src/delete.rs b/influxdb_influxql_parser/src/delete.rs new file mode 100644 index 0000000..7a494b3 --- /dev/null +++ b/influxdb_influxql_parser/src/delete.rs @@ -0,0 +1,107 @@ +//! Types and parsers for the [`DELETE`][sql] statement. +//! +//! [sql]: https://docs.influxdata.com/influxdb/v1.8/query_language/manage-database/#delete-series-with-delete + +use crate::common::{where_clause, ws0, ws1, WhereClause}; +use crate::internal::{expect, ParseResult}; +use crate::keywords::keyword; +use crate::simple_from_clause::{delete_from_clause, DeleteFromClause}; +use nom::branch::alt; +use nom::combinator::{map, opt}; +use nom::sequence::{pair, preceded}; +use std::fmt::{Display, Formatter}; + +/// Represents a `DELETE` statement. +#[derive(Clone, Debug, PartialEq)] +pub enum DeleteStatement { + /// A DELETE with a `FROM` clause specifying one or more measurements + /// and an optional `WHERE` clause to restrict which series are deleted. 
+ FromWhere { + /// Represents the `FROM` clause. + from: DeleteFromClause, + + /// Represents the optional `WHERE` clause. + condition: Option, + }, + + /// A `DELETE` with a `WHERE` clause to restrict which series are deleted. + Where(WhereClause), +} + +impl Display for DeleteStatement { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "DELETE")?; + + match self { + Self::FromWhere { from, condition } => { + write!(f, " {from}")?; + if let Some(where_clause) = condition { + write!(f, " {where_clause}")?; + } + } + Self::Where(where_clause) => write!(f, " {where_clause}")?, + }; + + Ok(()) + } +} + +/// Parse a `DELETE` statement. +pub(crate) fn delete_statement(i: &str) -> ParseResult<&str, DeleteStatement> { + // delete ::= "DELETE" ( from_clause where_clause? | where_clause ) + preceded( + keyword("DELETE"), + expect( + "invalid DELETE statement, expected FROM or WHERE", + preceded( + ws1, + alt(( + // delete ::= from_clause where_clause? + map( + pair(delete_from_clause, opt(preceded(ws0, where_clause))), + |(from, condition)| DeleteStatement::FromWhere { from, condition }, + ), + // delete ::= where_clause + map(where_clause, DeleteStatement::Where), + )), + ), + ), + )(i) +} + +#[cfg(test)] +mod test { + use crate::assert_expect_error; + use crate::delete::delete_statement; + + #[test] + fn test_delete() { + // Validate via the Display trait, as we don't need to validate the contents of the + // FROM and / or WHERE clauses, given they are tested in their on modules. 
+ + // Measurement name expressed as an identifier + let (_, got) = delete_statement("DELETE FROM foo").unwrap(); + assert_eq!(got.to_string(), "DELETE FROM foo"); + + // Measurement name expressed as a regular expression + let (_, got) = delete_statement("DELETE FROM /foo/").unwrap(); + assert_eq!(got.to_string(), "DELETE FROM /foo/"); + + let (_, got) = delete_statement("DELETE FROM foo WHERE time > 10").unwrap(); + assert_eq!(got.to_string(), "DELETE FROM foo WHERE time > 10"); + + let (_, got) = delete_statement("DELETE WHERE time > 10").unwrap(); + assert_eq!(got.to_string(), "DELETE WHERE time > 10"); + + // Fallible cases + assert_expect_error!( + delete_statement("DELETE"), + "invalid DELETE statement, expected FROM or WHERE" + ); + + assert_expect_error!( + delete_statement("DELETE FOO"), + "invalid DELETE statement, expected FROM or WHERE" + ); + } +} diff --git a/influxdb_influxql_parser/src/drop.rs b/influxdb_influxql_parser/src/drop.rs new file mode 100644 index 0000000..9080403 --- /dev/null +++ b/influxdb_influxql_parser/src/drop.rs @@ -0,0 +1,78 @@ +//! Types and parsers for the [`DROP MEASUREMENT`][sql] statement. +//! +//! [sql]: https://docs.influxdata.com/influxdb/v1.8/query_language/manage-database/#delete-measurements-with-drop-measurement + +use crate::common::ws1; +use crate::identifier::{identifier, Identifier}; +use crate::internal::{expect, ParseResult}; +use crate::keywords::keyword; +use nom::combinator::map; +use nom::sequence::{pair, preceded}; +use std::fmt::{Display, Formatter}; + +/// Represents a `DROP MEASUREMENT` statement. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DropMeasurementStatement { + /// The name of the measurement to delete. 
+ name: Identifier, +} + +impl Display for DropMeasurementStatement { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "DROP MEASUREMENT {}", self.name) + } +} + +pub(crate) fn drop_statement(i: &str) -> ParseResult<&str, DropMeasurementStatement> { + preceded( + pair(keyword("DROP"), ws1), + expect( + "invalid DROP statement, expected MEASUREMENT", + drop_measurement, + ), + )(i) +} + +fn drop_measurement(i: &str) -> ParseResult<&str, DropMeasurementStatement> { + preceded( + keyword("MEASUREMENT"), + map( + expect( + "invalid DROP MEASUREMENT statement, expected identifier", + identifier, + ), + |name| DropMeasurementStatement { name }, + ), + )(i) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::assert_expect_error; + + #[test] + fn test_drop_statement() { + drop_statement("DROP MEASUREMENT foo").unwrap(); + + // Fallible cases + assert_expect_error!( + drop_statement("DROP foo"), + "invalid DROP statement, expected MEASUREMENT" + ); + } + + #[test] + fn test_drop_measurement() { + let (_, got) = drop_measurement("MEASUREMENT \"foo\"").unwrap(); + assert_eq!(got, DropMeasurementStatement { name: "foo".into() }); + // validate Display + assert_eq!(got.to_string(), "DROP MEASUREMENT foo"); + + // Fallible cases + assert_expect_error!( + drop_measurement("MEASUREMENT 'foo'"), + "invalid DROP MEASUREMENT statement, expected identifier" + ); + } +} diff --git a/influxdb_influxql_parser/src/explain.rs b/influxdb_influxql_parser/src/explain.rs new file mode 100644 index 0000000..d96fbd7 --- /dev/null +++ b/influxdb_influxql_parser/src/explain.rs @@ -0,0 +1,291 @@ +//! Types and parsers for the [`EXPLAIN`][sql] statement. +//! +//! 
[sql]: https://docs.influxdata.com/influxdb/v1.8/query_language/spec/#explain + +#![allow(dead_code)] // Temporary + +use crate::common::ws1; +use crate::internal::{expect, ParseResult}; +use crate::keywords::keyword; +use crate::statement::{statement, Statement}; +use nom::branch::alt; +use nom::combinator::{map, opt, value}; +use nom::sequence::{preceded, tuple}; +use std::fmt::{Display, Formatter}; + +/// Represents various options for an `EXPLAIN` statement. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ExplainOption { + /// `EXPLAIN VERBOSE statement` + Verbose, + /// `EXPLAIN ANALYZE statement` + Analyze, + /// `EXPLAIN ANALYZE VERBOSE statement` + AnalyzeVerbose, +} + +impl Display for ExplainOption { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Verbose => f.write_str("VERBOSE"), + Self::Analyze => f.write_str("ANALYZE"), + Self::AnalyzeVerbose => f.write_str("ANALYZE VERBOSE"), + } + } +} + +/// Represents an `EXPLAIN` statement. +/// +/// ```text +/// explain ::= "EXPLAIN" explain_options? select_statement +/// explain_options ::= "VERBOSE" | ( "ANALYZE" "VERBOSE"? ) +/// ``` +#[derive(Debug, Clone, PartialEq)] +pub struct ExplainStatement { + /// Represents any options specified for the `EXPLAIN` statement. + pub options: Option, + + /// Represents the `SELECT` statement to be explained and / or analyzed. + pub statement: Box, +} + +impl Display for ExplainStatement { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str("EXPLAIN ")?; + if let Some(options) = &self.options { + write!(f, "{options} ")?; + } + Display::fmt(&self.statement, f) + } +} + +/// Parse an `EXPLAIN` statement. 
+pub(crate) fn explain_statement(i: &str) -> ParseResult<&str, ExplainStatement> { + map( + tuple(( + keyword("EXPLAIN"), + opt(preceded( + ws1, + alt(( + map( + preceded(keyword("ANALYZE"), opt(preceded(ws1, keyword("VERBOSE")))), + |v| match v { + // If the optional combinator is Some, then it matched VERBOSE + Some(_) => ExplainOption::AnalyzeVerbose, + _ => ExplainOption::Analyze, + }, + ), + value(ExplainOption::Verbose, keyword("VERBOSE")), + )), + )), + ws1, + expect( + "invalid EXPLAIN statement, expected InfluxQL statement", + statement, + ), + )), + |(_, options, _, statement)| ExplainStatement { + options, + statement: Box::new(statement), + }, + )(i) +} + +#[cfg(test)] +mod test { + use crate::assert_expect_error; + use crate::explain::{explain_statement, ExplainOption}; + use assert_matches::assert_matches; + + #[test] + fn test_explain_statement() { + // EXPLAIN SELECT cases + + let (remain, got) = explain_statement("EXPLAIN SELECT val from temp").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(got.options, None); + assert_eq!(got.to_string(), "EXPLAIN SELECT val FROM temp"); + + let (remain, got) = explain_statement("EXPLAIN VERBOSE SELECT val from temp").unwrap(); + assert_eq!(remain, ""); + assert_matches!(&got.options, Some(o) if *o == ExplainOption::Verbose); + assert_eq!(got.to_string(), "EXPLAIN VERBOSE SELECT val FROM temp"); + + let (remain, got) = explain_statement("EXPLAIN ANALYZE SELECT val from temp").unwrap(); + assert_eq!(remain, ""); + assert_matches!(&got.options, Some(o) if *o == ExplainOption::Analyze); + assert_eq!(got.to_string(), "EXPLAIN ANALYZE SELECT val FROM temp"); + + let (remain, got) = + explain_statement("EXPLAIN ANALYZE VERBOSE SELECT val from temp").unwrap(); + assert_eq!(remain, ""); + assert_matches!(&got.options, Some(o) if *o == ExplainOption::AnalyzeVerbose); + assert_eq!( + got.to_string(), + "EXPLAIN ANALYZE VERBOSE SELECT val FROM temp" + ); + + // EXPLAIN SHOW 
MEASUREMENTS cases + let (remain, got) = explain_statement("EXPLAIN SHOW MEASUREMENTS").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(got.options, None); + assert_eq!(got.to_string(), "EXPLAIN SHOW MEASUREMENTS"); + + let (remain, got) = explain_statement("EXPLAIN VERBOSE SHOW MEASUREMENTS").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::Verbose); + assert_eq!(got.to_string(), "EXPLAIN VERBOSE SHOW MEASUREMENTS"); + + let (remain, got) = explain_statement("EXPLAIN ANALYZE SHOW MEASUREMENTS").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::Analyze); + assert_eq!(got.to_string(), "EXPLAIN ANALYZE SHOW MEASUREMENTS"); + + let (remain, got) = explain_statement("EXPLAIN ANALYZE VERBOSE SHOW MEASUREMENTS").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::AnalyzeVerbose); + assert_eq!(got.to_string(), "EXPLAIN ANALYZE VERBOSE SHOW MEASUREMENTS"); + + // EXPLAIN SHOW TAG KEYS cases + let (remain, got) = explain_statement("EXPLAIN SHOW TAG KEYS").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(got.options, None); + assert_eq!(got.to_string(), "EXPLAIN SHOW TAG KEYS"); + + let (remain, got) = explain_statement("EXPLAIN VERBOSE SHOW TAG KEYS").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::Verbose); + assert_eq!(got.to_string(), "EXPLAIN VERBOSE SHOW TAG KEYS"); + + let (remain, got) = explain_statement("EXPLAIN ANALYZE SHOW TAG KEYS").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::Analyze); + assert_eq!(got.to_string(), "EXPLAIN ANALYZE SHOW 
TAG KEYS"); + + let (remain, got) = explain_statement("EXPLAIN ANALYZE VERBOSE SHOW TAG KEYS").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::AnalyzeVerbose); + assert_eq!(got.to_string(), "EXPLAIN ANALYZE VERBOSE SHOW TAG KEYS"); + + // EXPLAIN SHOW TAG VALUES cases + let (remain, got) = + explain_statement("EXPLAIN SHOW TAG VALUES WITH KEY = \"Key\"").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(got.options, None); + assert_eq!( + got.to_string(), + "EXPLAIN SHOW TAG VALUES WITH KEY = \"Key\"" + ); + + let (remain, got) = + explain_statement("EXPLAIN VERBOSE SHOW TAG VALUES WITH KEY = \"Key\"").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::Verbose); + assert_eq!( + got.to_string(), + "EXPLAIN VERBOSE SHOW TAG VALUES WITH KEY = \"Key\"" + ); + + let (remain, got) = + explain_statement("EXPLAIN ANALYZE SHOW TAG VALUES WITH KEY = \"Key\"").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::Analyze); + assert_eq!( + got.to_string(), + "EXPLAIN ANALYZE SHOW TAG VALUES WITH KEY = \"Key\"" + ); + + let (remain, got) = + explain_statement("EXPLAIN ANALYZE VERBOSE SHOW TAG VALUES WITH KEY = \"Key\"") + .unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::AnalyzeVerbose); + assert_eq!( + got.to_string(), + "EXPLAIN ANALYZE VERBOSE SHOW TAG VALUES WITH KEY = \"Key\"" + ); + + // EXPLAIN SHOW FIELD KEYS cases + let (remain, got) = explain_statement("EXPLAIN SHOW FIELD KEYS").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(got.options, None); + assert_eq!(got.to_string(), "EXPLAIN SHOW FIELD KEYS"); + + let (remain, got) = 
explain_statement("EXPLAIN VERBOSE SHOW FIELD KEYS").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::Verbose); + assert_eq!(got.to_string(), "EXPLAIN VERBOSE SHOW FIELD KEYS"); + + let (remain, got) = explain_statement("EXPLAIN ANALYZE SHOW FIELD KEYS").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::Analyze); + assert_eq!(got.to_string(), "EXPLAIN ANALYZE SHOW FIELD KEYS"); + + let (remain, got) = explain_statement("EXPLAIN ANALYZE VERBOSE SHOW FIELD KEYS").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::AnalyzeVerbose); + assert_eq!(got.to_string(), "EXPLAIN ANALYZE VERBOSE SHOW FIELD KEYS"); + + // EXPLAIN SHOW RETENTION POLICIES cases + let (remain, got) = explain_statement("EXPLAIN SHOW RETENTION POLICIES").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(got.options, None); + assert_eq!(got.to_string(), "EXPLAIN SHOW RETENTION POLICIES"); + + let (remain, got) = explain_statement("EXPLAIN VERBOSE SHOW RETENTION POLICIES").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::Verbose); + assert_eq!(got.to_string(), "EXPLAIN VERBOSE SHOW RETENTION POLICIES"); + + let (remain, got) = explain_statement("EXPLAIN ANALYZE SHOW RETENTION POLICIES").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::Analyze); + assert_eq!(got.to_string(), "EXPLAIN ANALYZE SHOW RETENTION POLICIES"); + + let (remain, got) = + explain_statement("EXPLAIN ANALYZE VERBOSE SHOW RETENTION POLICIES").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o 
== ExplainOption::AnalyzeVerbose); + assert_eq!( + got.to_string(), + "EXPLAIN ANALYZE VERBOSE SHOW RETENTION POLICIES" + ); + + // EXPLAIN SHOW DATABASES cases + let (remain, got) = explain_statement("EXPLAIN SHOW DATABASES").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(got.options, None); + assert_eq!(got.to_string(), "EXPLAIN SHOW DATABASES"); + + let (remain, got) = explain_statement("EXPLAIN VERBOSE SHOW DATABASES").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::Verbose); + assert_eq!(got.to_string(), "EXPLAIN VERBOSE SHOW DATABASES"); + + let (remain, got) = explain_statement("EXPLAIN ANALYZE SHOW DATABASES").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::Analyze); + assert_eq!(got.to_string(), "EXPLAIN ANALYZE SHOW DATABASES"); + + let (remain, got) = explain_statement("EXPLAIN ANALYZE VERBOSE SHOW DATABASES").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::AnalyzeVerbose); + assert_eq!(got.to_string(), "EXPLAIN ANALYZE VERBOSE SHOW DATABASES"); + + // NOTE: Nested EXPLAIN is valid; DataFusion will throw a "No Nested EXPLAIN" error later + let (remain, got) = + explain_statement("EXPLAIN ANALYZE EXPLAIN SELECT val from temp").unwrap(); + assert_eq!(remain, ""); // assert that all input was consumed + assert_matches!(&got.options, Some(o) if *o == ExplainOption::Analyze); + assert_eq!( + got.to_string(), + "EXPLAIN ANALYZE EXPLAIN SELECT val FROM temp" + ); + + // surfaces statement-specific errors + assert_expect_error!( + explain_statement("EXPLAIN ANALYZE SELECT cpu FROM 'foo'"), + "invalid FROM clause, expected identifier, regular expression or subquery" + ); + } +} diff --git a/influxdb_influxql_parser/src/expression.rs 
b/influxdb_influxql_parser/src/expression.rs new file mode 100644 index 0000000..ee765cc --- /dev/null +++ b/influxdb_influxql_parser/src/expression.rs @@ -0,0 +1,14 @@ +//! Types and parsers for arithmetic and conditional expressions. + +pub use arithmetic::*; +pub use conditional::*; + +/// Provides arithmetic expression parsing. +pub mod arithmetic; +/// Provides conditional expression parsing. +pub mod conditional; +/// Provides APIs to traverse an expression tree using closures. +pub mod walk; + +#[cfg(test)] +mod test_util; diff --git a/influxdb_influxql_parser/src/expression/arithmetic.rs b/influxdb_influxql_parser/src/expression/arithmetic.rs new file mode 100644 index 0000000..264370f --- /dev/null +++ b/influxdb_influxql_parser/src/expression/arithmetic.rs @@ -0,0 +1,1113 @@ +use crate::common::ws0; +use crate::identifier::unquoted_identifier; +use crate::internal::{expect, Error, ParseError, ParseResult}; +use crate::keywords::keyword; +use crate::literal::{literal_regex, Duration}; +use crate::timestamp::Timestamp; +use crate::{ + identifier::{identifier, Identifier}, + literal::Literal, + parameter::BindParameter, +}; +use nom::branch::alt; +use nom::bytes::complete::tag; +use nom::character::complete::char; +use nom::combinator::{cut, map, opt, value}; +use nom::multi::{many0, separated_list0}; +use nom::sequence::{delimited, pair, preceded, separated_pair, terminated, tuple}; +use num_traits::cast; +use std::fmt::{Display, Formatter, Write}; +use std::ops::Neg; + +/// Reference to a tag or field key. +#[derive(Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)] +pub struct VarRef { + /// The name of the tag or field. + pub name: Identifier, + + /// An optional data type selection specified using the `::` operator. + /// + /// When the `::` operator follows an identifier, it instructs InfluxQL to fetch + /// only data of the matching data type. 
+ /// + /// The `::` operator appears after an [`Identifier`] and may be described using + /// the following EBNF: + /// + /// ```text + /// variable_ref ::= identifier ( "::" data_type )? + /// data_type ::= "float" | "integer" | "boolean" | "string" | "tag" | "field" + /// ``` + /// + /// For example: + /// + /// ```text + /// SELECT foo::field, host::tag, usage_idle::integer, idle::boolean FROM cpu + /// ``` + /// + /// Specifies the following: + /// + /// * `foo::field` will return a field of any data type named `foo` + /// * `host::tag` will return a tag named `host` + /// * `usage_idle::integer` will return either a float or integer field named `usage_idle`, + /// and casting it to an `integer` + /// * `idle::boolean` will return a field named `idle` that has a matching data type of + /// `boolean` + pub data_type: Option, +} + +impl Display for VarRef { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let Self { name, data_type } = self; + write!(f, "{name}")?; + if let Some(d) = data_type { + write!(f, "::{d}")?; + } + Ok(()) + } +} + +/// Function call +#[derive(Clone, Debug, PartialEq)] +pub struct Call { + /// Represents the name of the function call. + pub name: String, + + /// Represents the list of arguments to the function call. + pub args: Vec, +} + +impl Display for Call { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let Self { name, args } = self; + write!(f, "{name}(")?; + if !args.is_empty() { + let args = args.as_slice(); + write!(f, "{}", args[0])?; + for arg in &args[1..] { + write!(f, ", {arg}")?; + } + } + write!(f, ")") + } +} + +/// Binary operations, such as `1 + 2`. +#[derive(Clone, Debug, PartialEq)] +pub struct Binary { + /// Represents the left-hand side of the binary expression. + pub lhs: Box, + /// Represents the operator to apply to the binary expression. + pub op: BinaryOperator, + /// Represents the right-hand side of the binary expression. 
+ pub rhs: Box, +} + +impl Display for Binary { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let Self { lhs, op, rhs } = self; + write!(f, "{lhs} {op} {rhs}") + } +} + +/// An InfluxQL arithmetic expression. +#[derive(Clone, Debug, PartialEq)] +pub enum Expr { + /// Reference to a tag or field key. + VarRef(VarRef), + + /// BindParameter identifier + BindParameter(BindParameter), + + /// Literal value such as 'foo', 5 or /^(a|b)$/ + Literal(Literal), + + /// A literal wildcard (`*`) with an optional data type selection. + Wildcard(Option), + + /// A DISTINCT `` expression. + Distinct(Identifier), + + /// Function call + Call(Call), + + /// Binary operations, such as `1 + 2`. + Binary(Binary), + + /// Nested expression, such as (foo = 'bar') or (1) + Nested(Box), +} + +impl From for Expr { + fn from(v: Literal) -> Self { + Self::Literal(v) + } +} + +impl From for Expr { + fn from(v: i64) -> Self { + Self::Literal(v.into()) + } +} + +impl From for Expr { + fn from(v: u64) -> Self { + Self::Literal(v.into()) + } +} + +impl From for Expr { + fn from(v: f64) -> Self { + Self::Literal(v.into()) + } +} + +impl From for Box { + fn from(v: u64) -> Self { + Self::new(v.into()) + } +} + +impl From for Box { + fn from(v: i64) -> Self { + Self::new(v.into()) + } +} + +impl From for Box { + fn from(v: i32) -> Self { + Self::new((v as i64).into()) + } +} + +impl Display for Expr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::VarRef(v) => write!(f, "{v}"), + Self::BindParameter(v) => write!(f, "{v}"), + Self::Literal(v) => write!(f, "{v}"), + Self::Binary(v) => write!(f, "{v}"), + Self::Nested(v) => write!(f, "({v})"), + Self::Call(v) => write!(f, "{v}"), + Self::Wildcard(Some(v)) => write!(f, "*::{v}"), + Self::Wildcard(None) => f.write_char('*'), + Self::Distinct(v) => write!(f, "DISTINCT {v}"), + } + } +} + +/// Traits to help creating InfluxQL [`Expr`]s containing +/// a [`VarRef`]. 
+pub trait AsVarRefExpr { + /// Creates an InfluxQL [`VarRef`] expression. + fn to_var_ref_expr(&self) -> Expr; +} + +impl AsVarRefExpr for str { + fn to_var_ref_expr(&self) -> Expr { + Expr::VarRef(VarRef { + name: self.into(), + data_type: None, + }) + } +} + +/// Specifies the data type of a wildcard (`*`) when using the `::` operator. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum WildcardType { + /// Indicates the wildcard refers to tags only. + Tag, + + /// Indicates the wildcard refers to fields only. + Field, +} + +impl Display for WildcardType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Tag => f.write_str("tag"), + Self::Field => f.write_str("field"), + } + } +} + +/// Represents the primitive data types of a [`Expr::VarRef`] when specified +/// using a [cast operation][cast]. +/// +/// InfluxQL only supports casting between [`Self::Float`] and [`Self::Integer`] types. +/// +/// [cast]: https://docs.influxdata.com/influxdb/v1.8/query_language/explore-data/#cast-operations +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum VarRefDataType { + /// Represents a 64-bit float. + Float, + /// Represents a 64-bit integer. + Integer, + /// Represents a 64-bit unsigned integer. + Unsigned, + /// Represents a UTF-8 string. + String, + /// Represents a boolean. + Boolean, + /// Represents a field. + Field, + /// Represents a tag. + Tag, + /// Represents a timestamp. + Timestamp, +} + +impl VarRefDataType { + /// Returns true if the receiver is a data type that identifies as a field type. + pub fn is_field_type(&self) -> bool { + *self < Self::Tag + } + + /// Returns true if the receiver is a data type that identifies as a tag type. + pub fn is_tag_type(&self) -> bool { + *self == Self::Tag + } + + /// Returns true if the receiver is a numeric type. 
+ pub fn is_numeric_type(&self) -> bool { + *self <= Self::Unsigned + } +} + +impl Display for VarRefDataType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Float => f.write_str("float"), + Self::Integer => f.write_str("integer"), + Self::Unsigned => f.write_str("unsigned"), + Self::String => f.write_str("string"), + Self::Boolean => f.write_str("boolean"), + Self::Tag => f.write_str("tag"), + Self::Field => f.write_str("field"), + Self::Timestamp => f.write_str("timestamp"), + } + } +} + +/// An InfluxQL unary operator. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum UnaryOperator { + /// Represents the unary `+` operator. + Plus, + /// Represents the unary `-` operator. + Minus, +} + +impl Display for UnaryOperator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Plus => f.write_char('+'), + Self::Minus => f.write_char('-'), + } + } +} + +/// An InfluxQL binary operators. +#[derive(Clone, Debug, Copy, PartialEq, Eq)] +pub enum BinaryOperator { + /// Represents the `+` operator. + Add, + /// Represents the `-` operator. + Sub, + /// Represents the `*` operator. + Mul, + /// Represents the `/` operator. + Div, + /// Represents the `%` or modulus operator. + Mod, + /// Represents the `&` or bitwise-and operator. + BitwiseAnd, + /// Represents the `|` or bitwise-or operator. + BitwiseOr, + /// Represents the `^` or bitwise-xor operator. 
+ BitwiseXor, +} + +impl BinaryOperator { + fn reduce_number(&self, lhs: T, rhs: T) -> T + where + T: num_traits::NumOps, + T: num_traits::identities::Zero, + { + match self { + Self::Add => lhs + rhs, + Self::Sub => lhs - rhs, + Self::Mul => lhs * rhs, + // Divide by zero yields zero per + // https://github.com/influxdata/influxql/blob/1ba470371ec093d57a726b143fe6ccbacf1b452b/ast.go#L5216-L5218 + Self::Div if rhs.is_zero() => T::zero(), + Self::Div => lhs / rhs, + Self::Mod => lhs % rhs, + _ => unreachable!(), + } + } + + /// Return a value by applying the operation defined by the receiver. + pub fn reduce(&self, lhs: T, rhs: T) -> T { + match self { + Self::Add | Self::Sub | Self::Mul | Self::Div | Self::Mod => { + self.reduce_number(lhs, rhs) + } + Self::BitwiseAnd => lhs & rhs, + Self::BitwiseOr => lhs | rhs, + Self::BitwiseXor => lhs ^ rhs, + } + } + + /// Return a value by applying the operation defined by the receiver or [`None`] + /// if the operation is not supported. + pub fn try_reduce(&self, lhs: T, rhs: U) -> Option + where + T: num_traits::Float, + U: num_traits::NumOps, + U: num_traits::NumCast, + { + match self { + Self::Add | Self::Sub | Self::Mul | Self::Div | Self::Mod => Some(self.reduce_number( + lhs, + match cast(rhs) { + Some(v) => v, + None => return None, + }, + )), + _ => None, + } + } +} + +impl Display for BinaryOperator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Add => f.write_char('+'), + Self::Sub => f.write_char('-'), + Self::Mul => f.write_char('*'), + Self::Div => f.write_char('/'), + Self::Mod => f.write_char('%'), + Self::BitwiseAnd => f.write_char('&'), + Self::BitwiseOr => f.write_char('|'), + Self::BitwiseXor => f.write_char('^'), + } + } +} + +/// Parse a unary expression. 
+fn unary(i: &str) -> ParseResult<&str, Expr> +where + T: ArithmeticParsers, +{ + let (i, op) = preceded( + ws0, + alt(( + value(UnaryOperator::Plus, char('+')), + value(UnaryOperator::Minus, char('-')), + )), + )(i)?; + + let (i, e) = factor::(i)?; + + // Unary minus is expressed by negating existing literals, + // or producing a binary arithmetic expression that multiplies + // Expr `e` by -1 + let e = if op == UnaryOperator::Minus { + match e { + Expr::Literal(Literal::Float(v)) => Expr::Literal(Literal::Float(v.neg())), + Expr::Literal(Literal::Integer(v)) => Expr::Literal(Literal::Integer(v.neg())), + Expr::Literal(Literal::Duration(v)) => Expr::Literal(Literal::Duration((v.0.neg()).into())), + Expr::Literal(Literal::Unsigned(v)) => { + if v == (i64::MAX as u64) + 1 { + // The minimum i64 is parsed as a Literal::Unsigned, as it exceeds + // int64::MAX, so we explicitly handle that case per + // https://github.com/influxdata/influxql/blob/7e7d61973256ffeef4b99edd0a89f18a9e52fa2d/parser.go#L2750-L2755 + Expr::Literal(Literal::Integer(i64::MIN)) + } else { + return Err(nom::Err::Failure(Error::from_message( + i, + "constant overflows signed integer", + ))); + } + }, + v @ Expr::VarRef { .. } | v @ Expr::Call { .. } | v @ Expr::Nested(..) | v @ Expr::BindParameter(..) => { + Expr::Binary(Binary { + lhs: Box::new(Expr::Literal(Literal::Integer(-1))), + op: BinaryOperator::Mul, + rhs: Box::new(v), + }) + } + _ => { + return Err(nom::Err::Failure(Error::from_message( + i, + "unexpected unary expression: expected literal integer, float, duration, field, function or parenthesis", + ))) + } + } + } else { + e + }; + + Ok((i, e)) +} + +/// Parse a parenthesis expression. +fn parens(i: &str) -> ParseResult<&str, Expr> +where + T: ArithmeticParsers, +{ + delimited( + preceded(ws0, char('(')), + map(arithmetic::, |e| Expr::Nested(e.into())), + preceded(ws0, char(')')), + )(i) +} + +/// Parse a function call expression. 
+/// +/// The `name` field of the [`Expr::Call`] variant is guaranteed to be in lowercase. +pub(crate) fn call_expression(i: &str) -> ParseResult<&str, Expr> +where + T: ArithmeticParsers, +{ + map( + separated_pair( + // special case to handle `DISTINCT`, which is allowed as an identifier + // in a call expression + map(alt((unquoted_identifier, keyword("DISTINCT"))), |n| { + n.to_ascii_lowercase() + }), + ws0, + delimited( + char('('), + alt(( + // A single regular expression to match 0 or more field keys + map(preceded(ws0, literal_regex), |re| vec![re.into()]), + // A list of Expr, separated by commas + separated_list0(preceded(ws0, char(',')), arithmetic::), + )), + cut(preceded(ws0, char(')'))), + ), + ), + |(name, args)| Expr::Call(Call { name, args }), + )(i) +} + +/// Parse a segmented identifier +/// +/// ```text +/// segmented_identifier ::= identifier | +/// ( identifier "." identifier ) | +/// ( identifier "." identifier? "." identifier ) +/// ``` +fn segmented_identifier(i: &str) -> ParseResult<&str, Identifier> { + let (remaining, (opt_prefix, name)) = pair( + opt(alt(( + // ident2 "." ident1 "." + map( + pair( + terminated(identifier, tag(".")), + terminated(identifier, tag(".")), + ), + |(ident2, ident1)| (Some(ident2), Some(ident1)), + ), + // identifier ".." + map(terminated(identifier, tag("..")), |ident2| { + (Some(ident2), None) + }), + // identifier "." + map(terminated(identifier, tag(".")), |ident1| { + (None, Some(ident1)) + }), + ))), + identifier, + )(i)?; + + Ok(( + remaining, + match opt_prefix { + Some((None, Some(ident1))) => format!("{}.{}", ident1.0, name.0).into(), + Some((Some(ident2), None)) => format!("{}..{}", ident2.0, name.0).into(), + Some((Some(ident2), Some(ident1))) => { + format!("{}.{}.{}", ident2.0, ident1.0, name.0).into() + } + _ => name, + }, + )) +} + +/// Parse a variable reference, which is a segmented identifier followed by an optional cast expression. 
+pub(crate) fn var_ref(i: &str) -> ParseResult<&str, Expr> { + map( + pair( + segmented_identifier, + opt(preceded( + tag("::"), + expect( + "invalid data type for tag or field reference, expected float, integer, unsigned, string, boolean, field, tag", + alt(( + value(VarRefDataType::Float, keyword("FLOAT")), + value(VarRefDataType::Integer, keyword("INTEGER")), + value(VarRefDataType::Unsigned, keyword("UNSIGNED")), + value(VarRefDataType::String, keyword("STRING")), + value(VarRefDataType::Boolean, keyword("BOOLEAN")), + value(VarRefDataType::Tag, keyword("TAG")), + value(VarRefDataType::Field, keyword("FIELD")) + )) + ) + )), + ), + |(name, data_type)| Expr::VarRef(VarRef { name, data_type }), + )(i) +} + +/// Parse precedence priority 1 operators. +/// +/// These are the highest precedence operators, and include parenthesis and the unary operators. +fn factor(i: &str) -> ParseResult<&str, Expr> +where + T: ArithmeticParsers, +{ + alt((unary::, parens::, T::operand))(i) +} + +/// Parse arithmetic, precedence priority 2 operators. +/// +/// This includes the multiplication, division, bitwise and, and modulus operators. +fn term(i: &str) -> ParseResult<&str, Expr> +where + T: ArithmeticParsers, +{ + let (input, left) = factor::(i)?; + let (input, remaining) = many0(tuple(( + preceded( + ws0, + alt(( + value(BinaryOperator::Mul, char('*')), + value(BinaryOperator::Div, char('/')), + value(BinaryOperator::BitwiseAnd, char('&')), + value(BinaryOperator::Mod, char('%')), + )), + ), + factor::, + )))(input)?; + Ok((input, reduce_expr(left, remaining))) +} + +/// Parse an arithmetic expression. +/// +/// This includes the addition, subtraction, bitwise or, and bitwise xor operators. 
+pub(crate) fn arithmetic(i: &str) -> ParseResult<&str, Expr> +where + T: ArithmeticParsers, +{ + let (input, left) = term::(i)?; + let (input, remaining) = many0(tuple(( + preceded( + ws0, + alt(( + value(BinaryOperator::Add, char('+')), + value(BinaryOperator::Sub, char('-')), + value(BinaryOperator::BitwiseOr, char('|')), + value(BinaryOperator::BitwiseXor, char('^')), + )), + ), + cut(term::), + )))(input)?; + Ok((input, reduce_expr(left, remaining))) +} + +/// A trait for customizing arithmetic parsers. +pub(crate) trait ArithmeticParsers { + /// Parse an operand of an arithmetic expression. + fn operand(i: &str) -> ParseResult<&str, Expr>; +} + +/// Folds `expr` and `remainder` into a [Expr::Binary] tree. +fn reduce_expr(expr: Expr, remainder: Vec<(BinaryOperator, Expr)>) -> Expr { + remainder.into_iter().fold(expr, |lhs, val| { + Expr::Binary(Binary { + lhs: lhs.into(), + op: val.0, + rhs: val.1.into(), + }) + }) +} + +/// Trait for converting a type to a [`Expr::Literal`] expression. +pub trait LiteralExpr { + /// Convert the receiver to a literal expression. + fn lit(self) -> Expr; +} + +/// Convert `v` to a literal expression. 
+pub fn lit(v: T) -> Expr { + v.lit() +} + +impl LiteralExpr for Literal { + fn lit(self) -> Expr { + Expr::Literal(self) + } +} + +impl LiteralExpr for Duration { + fn lit(self) -> Expr { + Expr::Literal(Literal::Duration(self)) + } +} + +impl LiteralExpr for bool { + fn lit(self) -> Expr { + Expr::Literal(Literal::Boolean(self)) + } +} + +impl LiteralExpr for i64 { + fn lit(self) -> Expr { + Expr::Literal(Literal::Integer(self)) + } +} + +impl LiteralExpr for f64 { + fn lit(self) -> Expr { + Expr::Literal(Literal::Float(self)) + } +} + +impl LiteralExpr for String { + fn lit(self) -> Expr { + Expr::Literal(Literal::String(self)) + } +} + +impl LiteralExpr for Timestamp { + fn lit(self) -> Expr { + Expr::Literal(Literal::Timestamp(self)) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::literal::literal_no_regex; + use crate::parameter::parameter; + use crate::{assert_expect_error, assert_failure, binary_op, nested, param, var_ref}; + + struct TestParsers; + + impl ArithmeticParsers for TestParsers { + fn operand(i: &str) -> ParseResult<&str, Expr> { + preceded( + ws0, + alt(( + map(literal_no_regex, Expr::Literal), + var_ref, + map(parameter, Expr::BindParameter), + )), + )(i) + } + } + + fn arithmetic_expression(i: &str) -> ParseResult<&str, Expr> { + arithmetic::(i) + } + + #[test] + fn test_arithmetic() { + let (_, got) = arithmetic_expression("5 + 51").unwrap(); + assert_eq!(got, binary_op!(5, Add, 51)); + + let (_, got) = arithmetic_expression("5 + $foo").unwrap(); + assert_eq!(got, binary_op!(5, Add, param!("foo"))); + + // Following two tests validate that operators of higher precedence + // are nested deeper in the AST. 
+ + let (_, got) = arithmetic_expression("5 % -3 | 2").unwrap(); + assert_eq!(got, binary_op!(binary_op!(5, Mod, -3), BitwiseOr, 2)); + + let (_, got) = arithmetic_expression("-3 | 2 % 5").unwrap(); + assert_eq!(got, binary_op!(-3, BitwiseOr, binary_op!(2, Mod, 5))); + + let (_, got) = arithmetic_expression("5 % 2 | -3").unwrap(); + assert_eq!(got, binary_op!(binary_op!(5, Mod, 2), BitwiseOr, -3)); + + let (_, got) = arithmetic_expression("2 | -3 % 5").unwrap(); + assert_eq!(got, binary_op!(2, BitwiseOr, binary_op!(-3, Mod, 5))); + + let (_, got) = arithmetic_expression("5 - -(3 | 2)").unwrap(); + assert_eq!( + got, + binary_op!( + 5, + Sub, + binary_op!(-1, Mul, nested!(binary_op!(3, BitwiseOr, 2))) + ) + ); + + let (_, got) = arithmetic_expression("2 | 5 % 3").unwrap(); + assert_eq!(got, binary_op!(2, BitwiseOr, binary_op!(5, Mod, 3))); + + // Expressions are still valid when unnecessary whitespace is omitted + + let (_, got) = arithmetic_expression("5+51").unwrap(); + assert_eq!(got, binary_op!(5, Add, 51)); + + let (_, got) = arithmetic_expression("5+$foo").unwrap(); + assert_eq!(got, binary_op!(5, Add, param!("foo"))); + + let (_, got) = arithmetic_expression("5- -(3|2)").unwrap(); + assert_eq!( + got, + binary_op!( + 5, + Sub, + binary_op!(-1, Mul, nested!(binary_op!(3, BitwiseOr, 2))) + ) + ); + + // whitespace is not significant between unary operators + let (_, got) = arithmetic_expression("5+-(3|2)").unwrap(); + assert_eq!( + got, + binary_op!( + 5, + Add, + binary_op!(-1, Mul, nested!(binary_op!(3, BitwiseOr, 2))) + ) + ); + + // Test unary max signed + let (_, got) = arithmetic_expression("-9223372036854775808").unwrap(); + assert_eq!(got, Expr::Literal(Literal::Integer(-9223372036854775808))); + + // Fallible cases + + // invalid operator / incomplete expression + assert_failure!(arithmetic_expression("5 || 3")); + assert_failure!(arithmetic_expression("5+--(3|2)")); + // exceeds i64::MIN + 
assert_failure!(arithmetic_expression("-9223372036854775809")); + } + + #[test] + fn test_var_ref() { + let (_, got) = var_ref("foo").unwrap(); + assert_eq!(got, var_ref!("foo")); + + // Whilst this is parsed as a 3-part name, it is treated as a quoted string 🙄 + // VarRefs are parsed as segmented identifiers + // + // * https://github.com/influxdata/influxql/blob/7e7d61973256ffeef4b99edd0a89f18a9e52fa2d/parser.go#L2515-L2516 + // + // and then the segments are joined as a single string + // + // * https://github.com/influxdata/influxql/blob/7e7d61973256ffeef4b99edd0a89f18a9e52fa2d/parser.go#L2551 + let (rem, got) = var_ref("db.rp.foo").unwrap(); + assert_eq!(got, var_ref!("db.rp.foo")); + assert_eq!(got.to_string(), r#""db.rp.foo""#); + assert_eq!(rem, ""); + + // with cast operators + + let (_, got) = var_ref("foo::float").unwrap(); + assert_eq!(got, var_ref!("foo", Float)); + let (_, got) = var_ref("foo::integer").unwrap(); + assert_eq!(got, var_ref!("foo", Integer)); + let (_, got) = var_ref("foo::unsigned").unwrap(); + assert_eq!(got, var_ref!("foo", Unsigned)); + let (_, got) = var_ref("foo::string").unwrap(); + assert_eq!(got, var_ref!("foo", String)); + let (_, got) = var_ref("foo::boolean").unwrap(); + assert_eq!(got, var_ref!("foo", Boolean)); + let (_, got) = var_ref("foo::field").unwrap(); + assert_eq!(got, var_ref!("foo", Field)); + let (_, got) = var_ref("foo::tag").unwrap(); + assert_eq!(got, var_ref!("foo", Tag)); + + // Fallible cases + + assert_expect_error!(var_ref("foo::invalid"), "invalid data type for tag or field reference, expected float, integer, unsigned, string, boolean, field, tag"); + } + + #[test] + fn test_spacing_and_remaining_input() { + // Validate that the remaining input is returned + let (got, _) = arithmetic_expression("foo - 1 + 2 LIMIT 10").unwrap(); + assert_eq!(got, " LIMIT 10"); + + // Any whitespace preceding the expression is consumed + let (got, _) = arithmetic_expression(" foo - 1 + 2").unwrap(); + assert_eq!(got, ""); 
+ + // Various whitespace separators are supported between tokens + let (got, _) = arithmetic_expression("foo\n | 1 \t + \n \t3").unwrap(); + assert!(got.is_empty()) + } + + #[test] + fn test_segmented_identifier() { + // Unquoted + let (rem, id) = segmented_identifier("part0").unwrap(); + assert_eq!(rem, ""); + assert_eq!(id.to_string(), "part0"); + + // id.id + let (rem, id) = segmented_identifier("part1.part0").unwrap(); + assert_eq!(rem, ""); + assert_eq!(id.to_string(), "\"part1.part0\""); + + // id..id + let (rem, id) = segmented_identifier("part2..part0").unwrap(); + assert_eq!(rem, ""); + assert_eq!(id.to_string(), "\"part2..part0\""); + + // id.id.id + let (rem, id) = segmented_identifier("part2.part1.part0").unwrap(); + assert_eq!(rem, ""); + assert_eq!(id.to_string(), "\"part2.part1.part0\""); + + // "id"."id".id + let (rem, id) = segmented_identifier(r#""part 2"."part 1".part0"#).unwrap(); + assert_eq!(rem, ""); + assert_eq!(id.to_string(), "\"part 2.part 1.part0\""); + + // Only parses 3 segments + let (rem, id) = segmented_identifier("part2.part1.part0.foo").unwrap(); + assert_eq!(rem, ".foo"); + assert_eq!(id.to_string(), "\"part2.part1.part0\""); + + // Quoted + let (rem, id) = segmented_identifier("\"part0\"").unwrap(); + assert_eq!(rem, ""); + assert_eq!(id.to_string(), "part0"); + + // Additional test cases, with compatibility proven via https://go.dev/play/p/k2150CJocVl + + let (rem, id) = segmented_identifier(r#""part" 2"."part 1".part0"#).unwrap(); + assert_eq!(rem, r#" 2"."part 1".part0"#); + assert_eq!(id.to_string(), "part"); + + let (rem, id) = segmented_identifier(r#""part" 2."part 1".part0"#).unwrap(); + assert_eq!(rem, r#" 2."part 1".part0"#); + assert_eq!(id.to_string(), "part"); + + let (rem, id) = segmented_identifier(r#""part "2"."part 1".part0"#).unwrap(); + assert_eq!(rem, r#"2"."part 1".part0"#); + assert_eq!(id.to_string(), r#""part ""#); + + let (rem, id) = segmented_identifier(r#""part ""2"."part 1".part0"#).unwrap(); + 
assert_eq!(rem, r#""2"."part 1".part0"#); + assert_eq!(id.to_string(), r#""part ""#); + } + + #[test] + fn test_display_expr() { + #[track_caller] + fn assert_display_expr(input: &str, expected: &str) { + let (_, e) = arithmetic_expression(input).unwrap(); + assert_eq!(e.to_string(), expected); + } + + assert_display_expr("5 + 51", "5 + 51"); + assert_display_expr("5 + -10", "5 + -10"); + assert_display_expr("-(5 % 6)", "-1 * (5 % 6)"); + + // vary spacing + assert_display_expr("( 5 + 6 ) * -( 7+ 8)", "(5 + 6) * -1 * (7 + 8)"); + + // multiple unary and parenthesis + assert_display_expr("(-(5 + 6) & -+( 7 + 8 ))", "(-1 * (5 + 6) & -1 * (7 + 8))"); + + // unquoted identifier + assert_display_expr("foo + 5", "foo + 5"); + + // identifier, negated + assert_display_expr("-foo + 5", "-1 * foo + 5"); + + // bind parameter identifier + assert_display_expr("foo + $0", "foo + $0"); + + // quoted identifier + assert_display_expr(r#""foo" + 'bar'"#, r#"foo + 'bar'"#); + + // quoted identifier, negated + assert_display_expr(r#"-"foo" + 'bar'"#, r#"-1 * foo + 'bar'"#); + + // quoted identifier with spaces, negated + assert_display_expr(r#"-"foo bar" + 'bar'"#, r#"-1 * "foo bar" + 'bar'"#); + + // Duration + assert_display_expr("6h30m", "6h30m"); + + // Negated + assert_display_expr("- 6h30m", "-6h30m"); + + // Validate other expression types + + assert_eq!(Expr::Wildcard(None).to_string(), "*"); + assert_eq!( + Expr::Wildcard(Some(WildcardType::Field)).to_string(), + "*::field" + ); + assert_eq!(Expr::Distinct("foo".into()).to_string(), "DISTINCT foo"); + + // can't parse literal regular expressions as part of an arithmetic expression + assert_failure!(arithmetic_expression(r#""foo" + /^(no|match)$/"#)); + } + + /// Test call expressions using `ConditionalExpression` + fn call(i: &str) -> ParseResult<&str, Expr> { + call_expression::(i) + } + + #[test] + fn test_call() { + #[track_caller] + fn assert_call(input: &str, expected: &str) { + let (_, ex) = call(input).unwrap(); + 
assert_eq!(ex.to_string(), expected); + } + + // These tests validate a `Call` expression and also it's Display implementation. + // We don't need to validate Expr trees, as we do that in the conditional and arithmetic + // tests. + + // No arguments + assert_call("FN()", "fn()"); + + // Single argument with surrounding whitespace + assert_call("FN ( 1 )", "fn(1)"); + + // Multiple arguments with varying whitespace + assert_call("FN ( 1,2\n,3,\t4 )", "fn(1, 2, 3, 4)"); + + // Arguments as expressions + assert_call("FN ( 1 + 2, foo, 'bar' )", "fn(1 + 2, foo, 'bar')"); + + // A single regular expression argument + assert_call("FN ( /foo/ )", "fn(/foo/)"); + + // Fallible cases + + call("FN ( 1").unwrap_err(); + call("FN ( 1, )").unwrap_err(); + call("FN ( 1,, 2 )").unwrap_err(); + + // Conditionals not supported + call("FN ( 1 = 2 )").unwrap_err(); + + // Multiple regular expressions not supported + call("FN ( /foo/, /bar/ )").unwrap_err(); + } + + #[test] + fn test_var_ref_display() { + assert_eq!( + Expr::VarRef(VarRef { + name: "foo".into(), + data_type: None + }) + .to_string(), + "foo" + ); + assert_eq!( + Expr::VarRef(VarRef { + name: "foo".into(), + data_type: Some(VarRefDataType::Field) + }) + .to_string(), + "foo::field" + ); + } + + #[test] + fn test_var_ref_data_type() { + use VarRefDataType::*; + + // Ensure ordering of data types relative to one another. 
+ + assert!(Float < Integer); + assert!(Integer < Unsigned); + assert!(Unsigned < String); + assert!(String < Boolean); + assert!(Boolean < Field); + assert!(Field < Tag); + + assert!(Float.is_field_type()); + assert!(Integer.is_field_type()); + assert!(Unsigned.is_field_type()); + assert!(String.is_field_type()); + assert!(Boolean.is_field_type()); + assert!(Field.is_field_type()); + assert!(Tag.is_tag_type()); + + assert!(!Float.is_tag_type()); + assert!(!Integer.is_tag_type()); + assert!(!Unsigned.is_tag_type()); + assert!(!String.is_tag_type()); + assert!(!Boolean.is_tag_type()); + assert!(!Field.is_tag_type()); + assert!(!Tag.is_field_type()); + + assert!(Float.is_numeric_type()); + assert!(Integer.is_numeric_type()); + assert!(Unsigned.is_numeric_type()); + assert!(!String.is_numeric_type()); + assert!(!Boolean.is_numeric_type()); + assert!(!Field.is_numeric_type()); + assert!(!Tag.is_numeric_type()); + } + + #[test] + fn test_binary_operator_reduce() { + use BinaryOperator::*; + + // + // Integer, Integer + // + + // Numeric operations + assert_eq!(Add.reduce(10, 2), 12); + assert_eq!(Sub.reduce(10, 2), 8); + assert_eq!(Mul.reduce(10, 2), 20); + assert_eq!(Div.reduce(10, 2), 5); + // Divide by zero yields zero + assert_eq!(Div.reduce(10, 0), 0); + assert_eq!(Mod.reduce(10, 2), 0); + // Bitwise operations + assert_eq!(BitwiseAnd.reduce(0b1111, 0b1010), 0b1010); + assert_eq!(BitwiseOr.reduce(0b0101, 0b1010), 0b1111); + assert_eq!(BitwiseXor.reduce(0b1101, 0b1010), 0b0111); + + // + // Float, Float + // + + assert_eq!(Add.try_reduce(10.0, 2.0).unwrap(), 12.0); + assert_eq!(Sub.try_reduce(10.0, 2.0).unwrap(), 8.0); + assert_eq!(Mul.try_reduce(10.0, 2.0).unwrap(), 20.0); + assert_eq!(Div.try_reduce(10.0, 2.0).unwrap(), 5.0); + // Divide by zero yields zero + assert_eq!(Div.try_reduce(10.0, 0.0).unwrap(), 0.0); + assert_eq!(Mod.try_reduce(10.0, 2.0).unwrap(), 0.0); + + // Bitwise operations + assert!(BitwiseAnd.try_reduce(1.0, 1.0).is_none()); + 
assert!(BitwiseOr.try_reduce(1.0, 1.0).is_none()); + assert!(BitwiseXor.try_reduce(1.0, 1.0).is_none()); + + // + // Float, Integer + // + + assert_eq!(Add.try_reduce(10.0, 2).unwrap(), 12.0); + assert_eq!(Sub.try_reduce(10.0, 2).unwrap(), 8.0); + assert_eq!(Mul.try_reduce(10.0, 2).unwrap(), 20.0); + assert_eq!(Div.try_reduce(10.0, 2).unwrap(), 5.0); + // Divide by zero yields zero + assert_eq!(Div.try_reduce(10.0, 0).unwrap(), 0.0); + assert_eq!(Mod.try_reduce(10.0, 2).unwrap(), 0.0); + } +} diff --git a/influxdb_influxql_parser/src/expression/conditional.rs b/influxdb_influxql_parser/src/expression/conditional.rs new file mode 100644 index 0000000..f34d696 --- /dev/null +++ b/influxdb_influxql_parser/src/expression/conditional.rs @@ -0,0 +1,631 @@ +use crate::common::{ws0, ParseError}; +use crate::expression::arithmetic::{ + arithmetic, call_expression, var_ref, ArithmeticParsers, Expr, +}; +use crate::expression::Call; +use crate::functions::is_scalar_math_function; +use crate::internal::{expect, verify, Error as InternalError, ParseResult}; +use crate::keywords::keyword; +use crate::literal::{literal_no_regex, literal_regex, Literal}; +use crate::parameter::parameter; +use crate::select::is_valid_now_call; +use nom::branch::alt; +use nom::bytes::complete::tag; +use nom::character::complete::char; +use nom::combinator::{map, value}; +use nom::multi::many0; +use nom::sequence::{delimited, preceded, tuple}; +use nom::Offset; +use std::fmt; +use std::fmt::{Display, Formatter, Write}; +use std::str::FromStr; + +/// Represents one of the conditional operators supported by [`ConditionalExpression::Binary`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConditionalOperator { + /// Represents the `=` operator. + Eq, + /// Represents the `!=` or `<>` operator. + NotEq, + /// Represents the `=~` (regular expression equals) operator. + EqRegex, + /// Represents the `!~` (regular expression not equals) operator. + NotEqRegex, + /// Represents the `<` operator. 
+ Lt, + /// Represents the `<=` operator. + LtEq, + /// Represents the `>` operator. + Gt, + /// Represents the `>=` operator. + GtEq, + /// Represents the `IN` operator. + In, + /// Represents the `AND` operator. + And, + /// Represents the `OR` operator. + Or, +} + +impl Display for ConditionalOperator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Eq => f.write_char('='), + Self::NotEq => f.write_str("!="), + Self::EqRegex => f.write_str("=~"), + Self::NotEqRegex => f.write_str("!~"), + Self::Lt => f.write_char('<'), + Self::LtEq => f.write_str("<="), + Self::Gt => f.write_char('>'), + Self::GtEq => f.write_str(">="), + Self::In => f.write_str("IN"), + Self::And => f.write_str("AND"), + Self::Or => f.write_str("OR"), + } + } +} + +/// Conditional binary operations, such as `foo = 'bar'` or `true AND false`. +#[derive(Debug, Clone, PartialEq)] +pub struct ConditionalBinary { + /// Represents the left-hand side of the conditional binary expression. + pub lhs: Box, + /// Represents the operator to apply to the conditional binary expression. + pub op: ConditionalOperator, + /// Represents the right-hand side of the conditional binary expression. + pub rhs: Box, +} + +impl Display for ConditionalBinary { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let Self { lhs, op, rhs } = self; + write!(f, "{lhs} {op} {rhs}") + } +} + +/// Represents a conditional expression. +#[derive(Debug, Clone, PartialEq)] +pub enum ConditionalExpression { + /// Represents an arithmetic expression. + Expr(Box), + + /// Binary operations, such as `foo = 'bar'` or `true AND false`. + Binary(ConditionalBinary), + + /// Represents a conditional expression enclosed in parenthesis. + Grouped(Box), +} + +impl ConditionalExpression { + /// Returns the inner arithmetic [`Expr`]. 
+ pub fn expr(&self) -> Option<&Expr> { + if let Self::Expr(expr) = self { + Some(expr) + } else { + None + } + } + + /// Return `self == other` + pub fn eq(self, other: Self) -> Self { + binary_cond(self, ConditionalOperator::Eq, other) + } + + /// Return `self != other` + pub fn not_eq(self, other: Self) -> Self { + binary_cond(self, ConditionalOperator::NotEq, other) + } + + /// Return `self > other` + pub fn gt(self, other: Self) -> Self { + binary_cond(self, ConditionalOperator::Gt, other) + } + + /// Return `self >= other` + pub fn gt_eq(self, other: Self) -> Self { + binary_cond(self, ConditionalOperator::GtEq, other) + } + + /// Return `self < other` + pub fn lt(self, other: Self) -> Self { + binary_cond(self, ConditionalOperator::Lt, other) + } + + /// Return `self <= other` + pub fn lt_eq(self, other: Self) -> Self { + binary_cond(self, ConditionalOperator::LtEq, other) + } + + /// Return `self AND other` + pub fn and(self, other: Self) -> Self { + binary_cond(self, ConditionalOperator::And, other) + } + + /// Return `self OR other` + pub fn or(self, other: Self) -> Self { + binary_cond(self, ConditionalOperator::Or, other) + } +} + +impl Display for ConditionalExpression { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Expr(v) => fmt::Display::fmt(v, f), + Self::Binary(v) => fmt::Display::fmt(v, f), + Self::Grouped(v) => write!(f, "({v})"), + } + } +} + +impl From for ConditionalExpression { + fn from(v: Literal) -> Self { + Self::Expr(Box::new(Expr::Literal(v))) + } +} + +/// Parse a parenthesis expression. 
+fn parens(i: &str) -> ParseResult<&str, ConditionalExpression> { + delimited( + preceded(ws0, char('(')), + map(conditional_expression, |e| { + ConditionalExpression::Grouped(e.into()) + }), + preceded(ws0, char(')')), + )(i) +} + +fn expr_or_group(i: &str) -> ParseResult<&str, ConditionalExpression> { + alt(( + map(arithmetic_expression, |v| { + ConditionalExpression::Expr(Box::new(v)) + }), + parens, + ))(i) +} + +/// Parse the conditional regular expression operators `=~` and `!~`. +fn conditional_regex(i: &str) -> ParseResult<&str, ConditionalExpression> { + let (input, f1) = expr_or_group(i)?; + let (input, exprs) = many0(tuple(( + preceded( + ws0, + alt(( + value(ConditionalOperator::EqRegex, tag("=~")), + value(ConditionalOperator::NotEqRegex, tag("!~")), + )), + ), + map( + expect( + "invalid conditional, expected regular expression", + preceded(ws0, literal_regex), + ), + From::from, + ), + )))(input)?; + Ok((input, reduce_expr(f1, exprs))) +} + +/// Parse conditional operators. +fn conditional(i: &str) -> ParseResult<&str, ConditionalExpression> { + let (input, f1) = conditional_regex(i)?; + let (input, exprs) = many0(tuple(( + preceded( + ws0, + alt(( + // try longest matches first + value(ConditionalOperator::LtEq, tag("<=")), + value(ConditionalOperator::GtEq, tag(">=")), + value(ConditionalOperator::NotEq, tag("!=")), + value(ConditionalOperator::NotEq, tag("<>")), + value(ConditionalOperator::Lt, char('<')), + value(ConditionalOperator::Gt, char('>')), + value(ConditionalOperator::Eq, char('=')), + )), + ), + expect("invalid conditional expression", conditional_regex), + )))(input)?; + Ok((input, reduce_expr(f1, exprs))) +} + +/// Parse conjunction operators, such as `AND`. 
+fn conjunction(i: &str) -> ParseResult<&str, ConditionalExpression> { + let (input, f1) = conditional(i)?; + let (input, exprs) = many0(tuple(( + value(ConditionalOperator::And, preceded(ws0, keyword("AND"))), + expect("invalid conditional expression", conditional), + )))(input)?; + Ok((input, reduce_expr(f1, exprs))) +} + +/// Parse disjunction operator, such as `OR`. +fn disjunction(i: &str) -> ParseResult<&str, ConditionalExpression> { + let (input, f1) = conjunction(i)?; + let (input, exprs) = many0(tuple(( + value(ConditionalOperator::Or, preceded(ws0, keyword("OR"))), + expect("invalid conditional expression", conjunction), + )))(input)?; + Ok((input, reduce_expr(f1, exprs))) +} + +/// Parse an InfluxQL conditional expression. +pub(crate) fn conditional_expression(i: &str) -> ParseResult<&str, ConditionalExpression> { + disjunction(i) +} + +/// Parse the input completely and return a [`ConditionalExpression`]. +/// +/// All leading and trailing whitespace is consumed. If any input remains after parsing, +/// an error is returned. 
+pub fn parse_conditional_expression(input: &str) -> Result { + let mut i: &str = input; + + // Consume whitespace from the input + (i, _) = ws0(i).expect("ws0 is infallible"); + + if i.is_empty() { + return Err(ParseError { + message: "unexpected eof".into(), + pos: 0, + }); + } + + let (mut i, cond) = match conditional_expression(i) { + Ok((i1, cond)) => (i1, cond), + Err(nom::Err::Failure(InternalError::Syntax { + input: pos, + message, + })) => { + return Err(ParseError { + message: message.into(), + pos: input.offset(pos), + }) + } + // any other error indicates an invalid expression + Err(_) => { + return Err(ParseError { + message: "invalid conditional expression".into(), + pos: input.offset(i), + }) + } + }; + + // Consume remaining whitespace from the input + (i, _) = ws0(i).expect("ws0 is infallible"); + + if !i.is_empty() { + return Err(ParseError { + message: "invalid conditional expression".into(), + pos: input.offset(i), + }); + } + + Ok(cond) +} + +impl FromStr for ConditionalExpression { + type Err = ParseError; + + fn from_str(s: &str) -> Result { + parse_conditional_expression(s) + } +} + +/// Folds `expr` and `remainder` into a [ConditionalExpression::Binary] tree. +fn reduce_expr( + expr: ConditionalExpression, + remainder: Vec<(ConditionalOperator, ConditionalExpression)>, +) -> ConditionalExpression { + remainder.into_iter().fold(expr, |lhs, val| { + ConditionalExpression::Binary(ConditionalBinary { + lhs: lhs.into(), + op: val.0, + rhs: val.1.into(), + }) + }) +} + +/// Returns true if `expr` is a valid [`Expr::Call`] expression for condtional expressions +/// in the WHERE clause. +pub(crate) fn is_valid_conditional_call(expr: &Expr) -> bool { + is_valid_now_call(expr) + || match expr { + Expr::Call(Call { name, .. 
}) => is_scalar_math_function(name), + _ => false, + } +} + +impl ConditionalExpression { + /// Parse the `now()` function call + fn call(i: &str) -> ParseResult<&str, Expr> { + verify( + "invalid expression, the only valid function calls are 'now' with no arguments, or scalar math functions", + call_expression::, + is_valid_conditional_call, + )(i) + } +} + +impl ArithmeticParsers for ConditionalExpression { + fn operand(i: &str) -> ParseResult<&str, Expr> { + preceded( + ws0, + alt(( + map(literal_no_regex, Expr::Literal), + Self::call, + var_ref, + map(parameter, Expr::BindParameter), + )), + )(i) + } +} + +/// Parse an arithmetic expression used by conditional expressions. +pub(crate) fn arithmetic_expression(i: &str) -> ParseResult<&str, Expr> { + arithmetic::(i) +} + +/// Return a new conditional expression, `lhs op rhs`. +pub fn binary_cond( + lhs: ConditionalExpression, + op: ConditionalOperator, + rhs: ConditionalExpression, +) -> ConditionalExpression { + ConditionalExpression::Binary(ConditionalBinary { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::expression::arithmetic::Expr; + use crate::{ + assert_expect_error, assert_failure, binary_op, call, cond_op, grouped, regex, var_ref, + }; + use test_helpers::assert_error; + + impl From for ConditionalExpression { + fn from(v: Expr) -> Self { + Self::Expr(Box::new(v)) + } + } + + impl From for Box { + fn from(v: i32) -> Self { + Self::new(ConditionalExpression::Expr(Box::new(Expr::Literal( + (v as i64).into(), + )))) + } + } + + impl From for Box { + fn from(v: i64) -> Self { + Self::new(ConditionalExpression::Expr(Box::new(Expr::Literal( + v.into(), + )))) + } + } + + impl From for Box { + fn from(v: u64) -> Self { + Self::new(ConditionalExpression::Expr(Box::new(Expr::Literal( + v.into(), + )))) + } + } + + impl From for Box { + fn from(v: Expr) -> Self { + Self::new(ConditionalExpression::Expr(v.into())) + } + } + + impl From> for Box { 
+ fn from(v: Box) -> Self { + Self::new(ConditionalExpression::Expr(v)) + } + } + + #[test] + fn test_arithmetic_expression() { + // now() function call is permitted + let (_, got) = arithmetic_expression("now() + 3").unwrap(); + assert_eq!(got, binary_op!(call!("now"), Add, 3)); + + // arithmetic functions calls are permitted + let (_, got) = arithmetic_expression("abs(f) + 3").unwrap(); + assert_eq!(got, binary_op!(call!("abs", var_ref!("f")), Add, 3)); + + // Fallible cases + + assert_expect_error!( + arithmetic_expression("sum(foo)"), + "invalid expression, the only valid function calls are 'now' with no arguments, or scalar math functions" + ); + + assert_expect_error!( + arithmetic_expression("now(1)"), + "invalid expression, the only valid function calls are 'now' with no arguments, or scalar math functions" + ); + } + + #[test] + fn test_conditional_expression() { + let (_, got) = conditional_expression("foo = 5").unwrap(); + assert_eq!(got, *cond_op!(var_ref!("foo"), Eq, 5)); + + let (_, got) = conditional_expression("foo != 5").unwrap(); + assert_eq!(got, *cond_op!(var_ref!("foo"), NotEq, 5)); + + let (_, got) = conditional_expression("foo > 5").unwrap(); + assert_eq!(got, *cond_op!(var_ref!("foo"), Gt, 5)); + + let (_, got) = conditional_expression("foo >= 5").unwrap(); + assert_eq!(got, *cond_op!(var_ref!("foo"), GtEq, 5)); + + let (_, got) = conditional_expression("foo < 5").unwrap(); + assert_eq!(got, *cond_op!(var_ref!("foo"), Lt, 5)); + + let (_, got) = conditional_expression("foo <= 5").unwrap(); + assert_eq!(got, *cond_op!(var_ref!("foo"), LtEq, 5)); + + let (_, got) = conditional_expression("foo > 5 + 6 ").unwrap(); + assert_eq!(got, *cond_op!(var_ref!("foo"), Gt, binary_op!(5, Add, 6))); + + let (_, got) = conditional_expression("5 <= -6").unwrap(); + assert_eq!(got, *cond_op!(5, LtEq, -6)); + + // simple expressions + let (_, got) = conditional_expression("true").unwrap(); + assert_eq!( + got, + 
ConditionalExpression::Expr(Box::new(Expr::Literal(true.into()))) + ); + + // Expressions are still valid when whitespace is omitted + + let (_, got) = conditional_expression("foo>5+6 ").unwrap(); + assert_eq!(got, *cond_op!(var_ref!("foo"), Gt, binary_op!(5, Add, 6))); + + let (_, got) = conditional_expression("5<=-6").unwrap(); + assert_eq!(got, *cond_op!(5, LtEq, -6)); + + // var refs with cast operator + let (_, got) = conditional_expression("foo::integer = 5").unwrap(); + assert_eq!(got, *cond_op!(var_ref!("foo", Integer), Eq, 5)); + + // Fallible cases + + // conditional expression must be complete + assert_failure!(conditional_expression("5 <=")); + + // should not accept a regex literal + assert_failure!(conditional_expression("5 = /regex/")); + } + + #[test] + fn test_logical_expression() { + let (_, got) = conditional_expression("5 AND 6").unwrap(); + assert_eq!(got, *cond_op!(5, And, 6)); + + let (_, got) = conditional_expression("5 AND 6 OR 7").unwrap(); + assert_eq!(got, *cond_op!(cond_op!(5, And, 6), Or, 7)); + + let (_, got) = conditional_expression("5 > 3 OR 6 = 7 AND 7 != 1").unwrap(); + assert_eq!( + got, + *cond_op!( + cond_op!(5, Gt, 3), + Or, + cond_op!(cond_op!(6, Eq, 7), And, cond_op!(7, NotEq, 1)) + ) + ); + + let (_, got) = conditional_expression("5 AND (6 OR 7)").unwrap(); + assert_eq!(got, *cond_op!(5, And, grouped!(cond_op!(6, Or, 7)))); + + // <> is recognised as != + let (_, got) = conditional_expression("5 <> 6").unwrap(); + assert_eq!(got, *cond_op!(5, NotEq, 6)); + + // In the following cases, we validate that the `OR` keyword is not eagerly + // parsed from substrings + let (got, _) = conditional_expression("foo = bar ORDER BY time ASC").unwrap(); + assert_eq!(got, " ORDER BY time ASC"); + + let (got, _) = conditional_expression("foo = bar OR1").unwrap(); + assert_eq!(got, " OR1"); + + // Whitespace is optional for certain characters + let (got, _) = conditional_expression("foo = bar OR(foo > bar) ORDER BY time ASC").unwrap(); + 
assert_eq!(got, " ORDER BY time ASC"); + + // Fallible cases + + // Expects Expr after operator + assert_failure!(conditional_expression("5 OR -")); + assert_failure!(conditional_expression("5 OR")); + assert_failure!(conditional_expression("5 AND")); + + // Can't use "and" as identifier + assert_failure!(conditional_expression("5 AND and OR 5")); + } + + #[test] + fn test_regex() { + let (_, got) = conditional_expression("foo =~ /(a > b)/").unwrap(); + assert_eq!(got, *cond_op!(var_ref!("foo"), EqRegex, regex!("(a > b)"))); + + let (_, got) = conditional_expression("foo !~ /bar/").unwrap(); + assert_eq!(got, *cond_op!(var_ref!("foo"), NotEqRegex, regex!("bar"))); + + // Expressions are still valid when whitespace is omitted + + let (_, got) = conditional_expression("foo=~/(a > b)/").unwrap(); + assert_eq!(got, *cond_op!(var_ref!("foo"), EqRegex, regex!("(a > b)"))); + + // Fallible cases + + // Expects a regex literal after regex conditional operators + assert_expect_error!( + conditional_expression("foo =~ 5"), + "invalid conditional, expected regular expression" + ); + assert_expect_error!( + conditional_expression("foo !~ 5"), + "invalid conditional, expected regular expression" + ); + } + + #[test] + fn test_display_expr() { + let (_, e) = conditional_expression("foo = 'test'").unwrap(); + assert_eq!(e.to_string(), "foo = 'test'"); + } + + #[test] + fn test_parse_conditional_expression() { + assert_eq!( + parse_conditional_expression("a>b").unwrap().to_string(), + "a > b" + ); + + // with leading and trailing whitespace + assert_eq!( + parse_conditional_expression(" a>b ").unwrap().to_string(), + "a > b" + ); + + // Fallible cases + + // Expected regular expression + assert_error!(parse_conditional_expression("a =~ 'foo'"), ref e @ ParseError { .. } if e.pos == 4); + + // Invalid operator + assert_error!(parse_conditional_expression("a ~= /foo/"), ref e @ ParseError { .. 
} if e.pos == 2); + } + + /// Validate the [`FromStr`] implementation for [`ConditionalExpression`]. + #[test] + fn test_conditional_expression_parse() { + let cond = " a>b ".parse::().unwrap(); + assert_eq!(cond.to_string(), "a > b"); + } + + #[test] + fn test_conditional_expression_expr() { + let cond: ConditionalExpression = "a + 1 > b - 2".parse().unwrap(); + assert!(cond.expr().is_none()); + + let cond: ConditionalExpression = "(a + 1 > b - 2)".parse().unwrap(); + assert!(cond.expr().is_none()); + + let cond: ConditionalExpression = "a + 1".parse().unwrap(); + assert_eq!(cond.expr().unwrap().to_string(), "a + 1"); + + let cond: ConditionalExpression = "(a + 1)".parse().unwrap(); + assert_eq!(cond.expr().unwrap().to_string(), "(a + 1)"); + } +} diff --git a/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expr-2.snap b/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expr-2.snap new file mode 100644 index 0000000..625a695 --- /dev/null +++ b/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expr-2.snap @@ -0,0 +1,7 @@ +--- +source: influxdb_influxql_parser/src/expression/walk.rs +expression: "walk_expr(\"now() + 1h\")" +--- +0: Call(Call { name: "now", args: [] }) +1: Literal(Duration(Duration(3600000000000))) +2: Binary(Binary { lhs: Call(Call { name: "now", args: [] }), op: Add, rhs: Literal(Duration(Duration(3600000000000))) }) diff --git a/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expr.snap b/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expr.snap new file mode 100644 index 0000000..219abf1 --- /dev/null +++ b/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expr.snap @@ -0,0 +1,7 @@ +--- +source: 
influxdb_influxql_parser/src/expression/walk.rs +expression: "walk_expr(\"5 + 6\")" +--- +0: Literal(Integer(5)) +1: Literal(Integer(6)) +2: Binary(Binary { lhs: Literal(Integer(5)), op: Add, rhs: Literal(Integer(6)) }) diff --git a/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expr_mut-2.snap b/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expr_mut-2.snap new file mode 100644 index 0000000..27cd9cc --- /dev/null +++ b/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expr_mut-2.snap @@ -0,0 +1,7 @@ +--- +source: influxdb_influxql_parser/src/expression/walk.rs +expression: "walk_expr_mut(\"now() + 1h\")" +--- +0: Call(Call { name: "now", args: [] }) +1: Literal(Duration(Duration(3600000000000))) +2: Binary(Binary { lhs: Call(Call { name: "now", args: [] }), op: Add, rhs: Literal(Duration(Duration(3600000000000))) }) diff --git a/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expr_mut.snap b/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expr_mut.snap new file mode 100644 index 0000000..6eb590b --- /dev/null +++ b/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expr_mut.snap @@ -0,0 +1,7 @@ +--- +source: influxdb_influxql_parser/src/expression/walk.rs +expression: "walk_expr_mut(\"5 + 6\")" +--- +0: Literal(Integer(5)) +1: Literal(Integer(6)) +2: Binary(Binary { lhs: Literal(Integer(5)), op: Add, rhs: Literal(Integer(6)) }) diff --git a/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expression-2.snap b/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expression-2.snap new file mode 100644 index 0000000..6ff8ee4 --- 
/dev/null +++ b/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expression-2.snap @@ -0,0 +1,11 @@ +--- +source: influxdb_influxql_parser/src/expression/walk.rs +expression: "walk_expression(\"time > now() + 1h\")" +--- +0: Arithmetic(VarRef(VarRef { name: Identifier("time"), data_type: None })) +1: Conditional(Expr(VarRef(VarRef { name: Identifier("time"), data_type: None }))) +2: Arithmetic(Call(Call { name: "now", args: [] })) +3: Arithmetic(Literal(Duration(Duration(3600000000000)))) +4: Arithmetic(Binary(Binary { lhs: Call(Call { name: "now", args: [] }), op: Add, rhs: Literal(Duration(Duration(3600000000000))) })) +5: Conditional(Expr(Binary(Binary { lhs: Call(Call { name: "now", args: [] }), op: Add, rhs: Literal(Duration(Duration(3600000000000))) }))) +6: Conditional(Binary(ConditionalBinary { lhs: Expr(VarRef(VarRef { name: Identifier("time"), data_type: None })), op: Gt, rhs: Expr(Binary(Binary { lhs: Call(Call { name: "now", args: [] }), op: Add, rhs: Literal(Duration(Duration(3600000000000))) })) })) diff --git a/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expression.snap b/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expression.snap new file mode 100644 index 0000000..aae7068 --- /dev/null +++ b/influxdb_influxql_parser/src/expression/snapshots/influxdb_influxql_parser__expression__walk__test__walk_expression.snap @@ -0,0 +1,13 @@ +--- +source: influxdb_influxql_parser/src/expression/walk.rs +expression: "walk_expression(\"5 + 6 = 2 + 9\")" +--- +0: Arithmetic(Literal(Integer(5))) +1: Arithmetic(Literal(Integer(6))) +2: Arithmetic(Binary(Binary { lhs: Literal(Integer(5)), op: Add, rhs: Literal(Integer(6)) })) +3: Conditional(Expr(Binary(Binary { lhs: Literal(Integer(5)), op: Add, rhs: Literal(Integer(6)) }))) +4: Arithmetic(Literal(Integer(2))) +5: 
Arithmetic(Literal(Integer(9))) +6: Arithmetic(Binary(Binary { lhs: Literal(Integer(2)), op: Add, rhs: Literal(Integer(9)) })) +7: Conditional(Expr(Binary(Binary { lhs: Literal(Integer(2)), op: Add, rhs: Literal(Integer(9)) }))) +8: Conditional(Binary(ConditionalBinary { lhs: Expr(Binary(Binary { lhs: Literal(Integer(5)), op: Add, rhs: Literal(Integer(6)) })), op: Eq, rhs: Expr(Binary(Binary { lhs: Literal(Integer(2)), op: Add, rhs: Literal(Integer(9)) })) })) diff --git a/influxdb_influxql_parser/src/expression/test_util.rs b/influxdb_influxql_parser/src/expression/test_util.rs new file mode 100644 index 0000000..2d18de4 --- /dev/null +++ b/influxdb_influxql_parser/src/expression/test_util.rs @@ -0,0 +1,130 @@ +#![cfg(test)] + +/// Constructs an [crate::expression::arithmetic::Expr::VarRef] expression. +#[macro_export] +macro_rules! var_ref { + ($NAME: literal) => { + $crate::expression::Expr::VarRef($crate::expression::VarRef { + name: $NAME.into(), + data_type: None, + }) + }; + + ($NAME: literal, $TYPE: ident) => { + $crate::expression::Expr::VarRef($crate::expression::VarRef { + name: $NAME.into(), + data_type: Some($crate::expression::arithmetic::VarRefDataType::$TYPE), + }) + }; +} + +/// Constructs a regular expression [crate::expression::arithmetic::Expr::Literal]. +#[macro_export] +macro_rules! regex { + ($EXPR: expr) => { + $crate::expression::arithmetic::Expr::Literal( + $crate::literal::Literal::Regex($EXPR.into()).into(), + ) + }; +} + +/// Constructs a [crate::expression::arithmetic::Expr::BindParameter] expression. +#[macro_export] +macro_rules! param { + ($EXPR: expr) => { + $crate::expression::arithmetic::Expr::BindParameter( + $crate::parameter::BindParameter::new($EXPR.into()).into(), + ) + }; +} + +/// Constructs a [crate::expression::conditional::ConditionalExpression::Grouped] expression. +#[macro_export] +macro_rules! 
grouped { + ($EXPR: expr) => { + <$crate::expression::conditional::ConditionalExpression as std::convert::Into< + Box<$crate::expression::conditional::ConditionalExpression>, + >>::into($crate::expression::conditional::ConditionalExpression::Grouped($EXPR.into())) + }; +} + +/// Constructs a [crate::expression::arithmetic::Expr::Nested] expression. +#[macro_export] +macro_rules! nested { + ($EXPR: expr) => { + <$crate::expression::arithmetic::Expr as std::convert::Into< + Box<$crate::expression::arithmetic::Expr>, + >>::into($crate::expression::arithmetic::Expr::Nested($EXPR.into())) + }; +} + +/// Constructs a [crate::expression::arithmetic::Expr::Call] expression. +#[macro_export] +macro_rules! call { + ($NAME:literal) => { + $crate::expression::Expr::Call($crate::expression::Call { + name: $NAME.into(), + args: vec![], + }) + }; + ($NAME:literal, $( $ARG:expr ),+) => { + $crate::expression::Expr::Call($crate::expression::Call { + name: $NAME.into(), + args: vec![$( $ARG ),+], + }) + }; +} + +/// Constructs a [crate::expression::arithmetic::Expr::Distinct] expression. +#[macro_export] +macro_rules! distinct { + ($IDENT:literal) => { + $crate::expression::arithmetic::Expr::Distinct($IDENT.into()) + }; +} + +/// Constructs a [crate::expression::arithmetic::Expr::Wildcard] expression. +#[macro_export] +macro_rules! wildcard { + () => { + $crate::expression::arithmetic::Expr::Wildcard(None) + }; + (tag) => { + $crate::expression::arithmetic::Expr::Wildcard(Some( + $crate::expression::arithmetic::WildcardType::Tag, + )) + }; + (field) => { + $crate::expression::arithmetic::Expr::Wildcard(Some( + $crate::expression::arithmetic::WildcardType::Field, + )) + }; +} + +/// Constructs a [crate::expression::arithmetic::Expr::Binary] expression. +#[macro_export] +macro_rules! 
binary_op { + ($LHS: expr, $OP: ident, $RHS: expr) => { + $crate::expression::Expr::Binary($crate::expression::Binary { + lhs: $LHS.into(), + op: $crate::expression::BinaryOperator::$OP, + rhs: $RHS.into(), + }) + }; +} + +/// Constructs a [crate::expression::conditional::ConditionalExpression::Binary] expression. +#[macro_export] +macro_rules! cond_op { + ($LHS: expr, $OP: ident, $RHS: expr) => { + <$crate::expression::ConditionalExpression as std::convert::Into< + Box<$crate::expression::ConditionalExpression>, + >>::into($crate::expression::ConditionalExpression::Binary( + $crate::expression::ConditionalBinary { + lhs: $LHS.into(), + op: $crate::expression::ConditionalOperator::$OP, + rhs: $RHS.into(), + }, + )) + }; +} diff --git a/influxdb_influxql_parser/src/expression/walk.rs b/influxdb_influxql_parser/src/expression/walk.rs new file mode 100644 index 0000000..352633d --- /dev/null +++ b/influxdb_influxql_parser/src/expression/walk.rs @@ -0,0 +1,205 @@ +use crate::expression::{Binary, Call, ConditionalBinary, ConditionalExpression, Expr}; + +/// Expression distinguishes InfluxQL [`ConditionalExpression`] or [`Expr`] +/// nodes when visiting a [`ConditionalExpression`] tree. See [`walk_expression`]. +#[derive(Debug)] +pub enum Expression<'a> { + /// Specifies a conditional expression. + Conditional(&'a ConditionalExpression), + /// Specifies an arithmetic expression. + Arithmetic(&'a Expr), +} + +/// ExpressionMut is the same as [`Expression`] with the exception that +/// it provides mutable access to the nodes of the tree. +#[derive(Debug)] +pub enum ExpressionMut<'a> { + /// Specifies a conditional expression. + Conditional(&'a mut ConditionalExpression), + /// Specifies an arithmetic expression. + Arithmetic(&'a mut Expr), +} + +/// Perform a depth-first traversal of an expression tree. 
+pub fn walk_expression<'a, B>( + node: &'a ConditionalExpression, + visit: &mut impl FnMut(Expression<'a>) -> std::ops::ControlFlow, +) -> std::ops::ControlFlow { + match node { + ConditionalExpression::Expr(n) => walk_expr(n, &mut |n| visit(Expression::Arithmetic(n)))?, + ConditionalExpression::Binary(ConditionalBinary { lhs, rhs, .. }) => { + walk_expression(lhs, visit)?; + walk_expression(rhs, visit)?; + } + ConditionalExpression::Grouped(n) => walk_expression(n, visit)?, + } + + visit(Expression::Conditional(node)) +} + +/// Perform a depth-first traversal of a mutable arithmetic or conditional expression tree. +pub fn walk_expression_mut( + node: &mut ConditionalExpression, + visit: &mut impl FnMut(ExpressionMut<'_>) -> std::ops::ControlFlow, +) -> std::ops::ControlFlow { + match node { + ConditionalExpression::Expr(n) => { + walk_expr_mut(n, &mut |n| visit(ExpressionMut::Arithmetic(n)))? + } + ConditionalExpression::Binary(ConditionalBinary { lhs, rhs, .. }) => { + walk_expression_mut(lhs, visit)?; + walk_expression_mut(rhs, visit)?; + } + ConditionalExpression::Grouped(n) => walk_expression_mut(n, visit)?, + } + + visit(ExpressionMut::Conditional(node)) +} + +/// Perform a depth-first traversal of the arithmetic expression tree. +pub fn walk_expr<'a, B>( + expr: &'a Expr, + visit: &mut impl FnMut(&'a Expr) -> std::ops::ControlFlow, +) -> std::ops::ControlFlow { + match expr { + Expr::Binary(Binary { lhs, rhs, .. }) => { + walk_expr(lhs, visit)?; + walk_expr(rhs, visit)?; + } + Expr::Nested(n) => walk_expr(n, visit)?, + Expr::Call(Call { args, .. }) => { + args.iter().try_for_each(|n| walk_expr(n, visit))?; + } + Expr::VarRef { .. } + | Expr::BindParameter(_) + | Expr::Literal(_) + | Expr::Wildcard(_) + | Expr::Distinct(_) => {} + } + + visit(expr) +} + +/// Perform a depth-first traversal of a mutable arithmetic expression tree. 
+pub fn walk_expr_mut( + expr: &mut Expr, + visit: &mut impl FnMut(&mut Expr) -> std::ops::ControlFlow, +) -> std::ops::ControlFlow { + match expr { + Expr::Binary(Binary { lhs, rhs, .. }) => { + walk_expr_mut(lhs, visit)?; + walk_expr_mut(rhs, visit)?; + } + Expr::Nested(n) => walk_expr_mut(n, visit)?, + Expr::Call(Call { args, .. }) => { + args.iter_mut().try_for_each(|n| walk_expr_mut(n, visit))?; + } + Expr::VarRef { .. } + | Expr::BindParameter(_) + | Expr::Literal(_) + | Expr::Wildcard(_) + | Expr::Distinct(_) => {} + } + + visit(expr) +} + +#[cfg(test)] +mod test { + use crate::expression::walk::{walk_expr_mut, walk_expression_mut, ExpressionMut}; + use crate::expression::{ + arithmetic_expression, conditional_expression, ConditionalBinary, ConditionalExpression, + ConditionalOperator, Expr, VarRef, + }; + use crate::literal::Literal; + + #[test] + fn test_walk_expression() { + fn walk_expression(s: &str) -> String { + let (_, ref expr) = conditional_expression(s).unwrap(); + let mut calls = Vec::new(); + let mut call_no = 0; + super::walk_expression::<()>(expr, &mut |n| { + calls.push(format!("{call_no}: {n:?}")); + call_no += 1; + std::ops::ControlFlow::Continue(()) + }); + calls.join("\n") + } + + insta::assert_display_snapshot!(walk_expression("5 + 6 = 2 + 9")); + insta::assert_display_snapshot!(walk_expression("time > now() + 1h")); + } + + #[test] + fn test_walk_expression_mut_modify() { + let (_, ref mut expr) = conditional_expression("foo + bar + 5 =~ /str/").unwrap(); + walk_expression_mut::<()>(expr, &mut |e| { + match e { + ExpressionMut::Arithmetic(n) => match n { + Expr::VarRef(VarRef { name, .. }) => *name = format!("c_{name}").into(), + Expr::Literal(Literal::Integer(v)) => *v *= 10, + Expr::Literal(Literal::Regex(v)) => *v = format!("c_{}", v.0).into(), + _ => {} + }, + ExpressionMut::Conditional(n) => { + if let ConditionalExpression::Binary(ConditionalBinary { op, .. 
}) = n { + *op = ConditionalOperator::NotEqRegex + } + } + } + std::ops::ControlFlow::Continue(()) + }); + assert_eq!(expr.to_string(), "c_foo + c_bar + 50 !~ /c_str/") + } + + #[test] + fn test_walk_expr() { + fn walk_expr(s: &str) -> String { + let (_, expr) = arithmetic_expression(s).unwrap(); + let mut calls = Vec::new(); + let mut call_no = 0; + super::walk_expr::<()>(&expr, &mut |n| { + calls.push(format!("{call_no}: {n:?}")); + call_no += 1; + std::ops::ControlFlow::Continue(()) + }); + calls.join("\n") + } + + insta::assert_display_snapshot!(walk_expr("5 + 6")); + insta::assert_display_snapshot!(walk_expr("now() + 1h")); + } + + #[test] + fn test_walk_expr_mut() { + fn walk_expr_mut(s: &str) -> String { + let (_, mut expr) = arithmetic_expression(s).unwrap(); + let mut calls = Vec::new(); + let mut call_no = 0; + super::walk_expr_mut::<()>(&mut expr, &mut |n| { + calls.push(format!("{call_no}: {n:?}")); + call_no += 1; + std::ops::ControlFlow::Continue(()) + }); + calls.join("\n") + } + + insta::assert_display_snapshot!(walk_expr_mut("5 + 6")); + insta::assert_display_snapshot!(walk_expr_mut("now() + 1h")); + } + + #[test] + fn test_walk_expr_mut_modify() { + let (_, mut expr) = arithmetic_expression("foo + bar + 5").unwrap(); + walk_expr_mut::<()>(&mut expr, &mut |e| { + match e { + Expr::VarRef(VarRef { name, .. }) => *name = format!("c_{name}").into(), + Expr::Literal(Literal::Integer(v)) => *v *= 10, + _ => {} + } + std::ops::ControlFlow::Continue(()) + }); + assert_eq!(expr.to_string(), "c_foo + c_bar + 50") + } +} diff --git a/influxdb_influxql_parser/src/functions.rs b/influxdb_influxql_parser/src/functions.rs new file mode 100644 index 0000000..b42103e --- /dev/null +++ b/influxdb_influxql_parser/src/functions.rs @@ -0,0 +1,74 @@ +//! # [Functions] supported by InfluxQL +//! +//! 
[Functions]: https://docs.influxdata.com/influxdb/v1.8/query_language/functions/ + +use std::collections::HashSet; + +use once_cell::sync::Lazy; + +/// Returns `true` if `name` is a mathematical scalar function +/// supported by InfluxQL. +pub fn is_scalar_math_function(name: &str) -> bool { + static FUNCTIONS: Lazy> = Lazy::new(|| { + HashSet::from([ + "abs", "sin", "cos", "tan", "asin", "acos", "atan", "atan2", "exp", "log", "ln", + "log2", "log10", "sqrt", "pow", "floor", "ceil", "round", + ]) + }); + + FUNCTIONS.contains(name) +} + +/// Returns `true` if `name` is an aggregate or aggregate function +/// supported by InfluxQL. +pub fn is_aggregate_function(name: &str) -> bool { + static FUNCTIONS: Lazy> = Lazy::new(|| { + HashSet::from([ + // Scalar-like functions + "cumulative_sum", + "derivative", + "difference", + "elapsed", + "moving_average", + "non_negative_derivative", + "non_negative_difference", + // Selector functions + "bottom", + "first", + "last", + "max", + "min", + "percentile", + "sample", + "top", + // Aggregate functions + "count", + "integral", + "mean", + "median", + "mode", + "spread", + "stddev", + "sum", + // Prediction functions + "holt_winters", + "holt_winters_with_fit", + // Technical analysis functions + "chande_momentum_oscillator", + "exponential_moving_average", + "double_exponential_moving_average", + "kaufmans_efficiency_ratio", + "kaufmans_adaptive_moving_average", + "triple_exponential_moving_average", + "triple_exponential_derivative", + "relative_strength_index", + ]) + }); + + FUNCTIONS.contains(name) +} + +/// Returns `true` if `name` is `"now"`. +pub fn is_now_function(name: &str) -> bool { + name == "now" +} diff --git a/influxdb_influxql_parser/src/identifier.rs b/influxdb_influxql_parser/src/identifier.rs new file mode 100644 index 0000000..dcbc2fb --- /dev/null +++ b/influxdb_influxql_parser/src/identifier.rs @@ -0,0 +1,167 @@ +//! # Parse an InfluxQL [identifier] +//! +//! 
Identifiers are parsed using the following rules: +//! +//! * double quoted identifiers can contain any unicode character other than a new line +//! * double quoted identifiers can contain escaped characters, namely `\"`, `\n`, `\t`, `\\` and `\'` +//! * double quoted identifiers can contain [InfluxQL keywords][keywords] +//! * unquoted identifiers must start with an upper or lowercase ASCII character or `_` +//! * unquoted identifiers may contain only ASCII letters, decimal digits, and `_` +//! * identifiers may be preceded by whitespace +//! +//! [identifier]: https://docs.influxdata.com/influxdb/v1.8/query_language/spec/#identifiers +//! [keywords]: https://docs.influxdata.com/influxdb/v1.8/query_language/spec/#keywords + +use crate::common::ws0; +use crate::internal::ParseResult; +use crate::keywords::sql_keyword; +use crate::string::double_quoted_string; +use crate::{impl_tuple_clause, write_quoted_string}; +use nom::branch::alt; +use nom::bytes::complete::tag; +use nom::character::complete::{alpha1, alphanumeric1}; +use nom::combinator::{map, not, recognize}; +use nom::multi::many0_count; +use nom::sequence::{pair, preceded}; +use std::fmt::{Display, Formatter, Write}; +use std::{fmt, mem}; + +/// Parse an unquoted InfluxQL identifier. +pub(crate) fn unquoted_identifier(i: &str) -> ParseResult<&str, &str> { + preceded( + not(sql_keyword), + recognize(pair( + alt((alpha1, tag("_"))), + many0_count(alt((alphanumeric1, tag("_")))), + )), + )(i) +} + +/// A type that represents an InfluxQL identifier. +#[derive(Clone, Debug, Eq, Hash, PartialEq, Ord, PartialOrd)] +pub struct Identifier(pub(crate) String); + +impl_tuple_clause!(Identifier, String); + +impl From<&str> for Identifier { + fn from(s: &str) -> Self { + Self(s.to_string()) + } +} + +impl Identifier { + /// Returns true if the identifier requires quotes. 
+ pub fn requires_quotes(&self) -> bool { + nom::sequence::terminated(unquoted_identifier, nom::combinator::eof)(&self.0).is_err() + } + + /// Takes the string value out of the identifier, leaving a default string value in its place. + pub fn take(&mut self) -> String { + mem::take(&mut self.0) + } +} + +impl Display for Identifier { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write_quoted_string!(f, '"', self.0.as_str(), unquoted_identifier, '\n' => "\\n", '\\' => "\\\\", '"' => "\\\""); + Ok(()) + } +} + +/// Parses an InfluxQL [Identifier]. +/// +/// EBNF for an identifier is approximately: +/// +/// ```text +/// identifier ::= whitespace? ( quoted_identifier | unquoted_identifier ) +/// unquoted_identifier ::= [_a..zA..Z] [_a..zA..Z0..9]* +/// quoted_identifier ::= '"' [^"\n] '"' +/// ``` +pub(crate) fn identifier(i: &str) -> ParseResult<&str, Identifier> { + // See: https://github.com/influxdata/influxql/blob/7e7d61973256ffeef4b99edd0a89f18a9e52fa2d/parser.go#L432-L438 + preceded( + ws0, + alt(( + map(unquoted_identifier, Into::into), + map(double_quoted_string, Into::into), + )), + )(i) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_unquoted_identifier() { + // all ascii + let (_, got) = unquoted_identifier("cpu").unwrap(); + assert_eq!(got, "cpu"); + + // all valid chars + let (_, got) = unquoted_identifier("cpu_0").unwrap(); + assert_eq!(got, "cpu_0"); + + // begin with underscore + let (_, got) = unquoted_identifier("_cpu_0").unwrap(); + assert_eq!(got, "_cpu_0"); + + // ┌─────────────────────────────┐ + // │ Fallible tests │ + // └─────────────────────────────┘ + + // start with number + unquoted_identifier("0cpu").unwrap_err(); + + // is a keyword + unquoted_identifier("as").unwrap_err(); + } + + #[test] + fn test_identifier() { + // quoted + let (_, got) = identifier("\"quick draw\"").unwrap(); + assert_eq!(got, "quick draw".into()); + // validate that `as_str` returns the unquoted string + assert_eq!(got.as_str(), 
"quick draw"); + + // unquoted + let (_, got) = identifier("quick_draw").unwrap(); + assert_eq!(got, "quick_draw".into()); + + // leading whitespace + let (_, got) = identifier(" quick_draw").unwrap(); + assert_eq!(got, "quick_draw".into()); + } + + #[test] + fn test_identifier_display() { + // Identifier properly escapes specific characters and quotes output + let got = Identifier("quick\n\t\\\"'draw \u{1f47d}".into()).to_string(); + assert_eq!(got, r#""quick\n \\\"'draw 👽""#); + + // Identifier displays unquoted output + let got = Identifier("quick_draw".into()).to_string(); + assert_eq!(got, "quick_draw"); + } + + #[test] + fn test_identifier_requires_quotes() { + // Following examples require quotes + + // Quotes, spaces, non-ASCII + assert!(Identifier("quick\n\t\\\"'draw \u{1f47d}".into()).requires_quotes()); + // non-ASCII + assert!(Identifier("quick_\u{1f47d}".into()).requires_quotes()); + // starts with number + assert!(Identifier("0quick".into()).requires_quotes()); + + // Following examples do not require quotes + + // starts with underscore + assert!(!Identifier("_quick".into()).requires_quotes()); + + // Only ASCII, non-space + assert!(!Identifier("quick_90".into()).requires_quotes()); + } +} diff --git a/influxdb_influxql_parser/src/internal.rs b/influxdb_influxql_parser/src/internal.rs new file mode 100644 index 0000000..90b2f59 --- /dev/null +++ b/influxdb_influxql_parser/src/internal.rs @@ -0,0 +1,133 @@ +//! Internal result and error types used to build InfluxQL parsers +//! +use nom::error::{ErrorKind as NomErrorKind, ParseError as NomParseError}; +use nom::Parser; +use std::borrow::Borrow; +use std::fmt::{Display, Formatter}; + +/// This trait must be implemented in order to use the [`map_fail`] and +/// [`expect`] functions for generating user-friendly error messages. 
+pub(crate) trait ParseError<'a>: NomParseError<&'a str> + Sized { + fn from_message(input: &'a str, message: &'static str) -> Self; +} + +/// An internal error type used to build InfluxQL parsers. +#[derive(Debug, PartialEq, Eq)] +pub enum Error { + Syntax { input: I, message: &'static str }, + Nom(I, NomErrorKind), +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Syntax { input: _, message } => { + write!(f, "Syntax error: {message}") + } + Self::Nom(_, kind) => write!(f, "nom error: {kind:?}"), + } + } +} + +impl<'a> ParseError<'a> for Error<&'a str> { + fn from_message(input: &'a str, message: &'static str) -> Self { + Self::Syntax { input, message } + } +} + +/// Applies a function returning a [`ParseResult`] over the result of the `parser`. +/// If the parser returns an error, the result will be mapped to an unrecoverable +/// [`nom::Err::Failure`] with the specified `message` for additional context. +pub(crate) fn map_fail<'a, O1, O2, E: ParseError<'a>, E2, F, G>( + message: &'static str, + mut parser: F, + mut f: G, +) -> impl FnMut(&'a str) -> ParseResult<&'a str, O2, E> +where + F: Parser<&'a str, O1, E>, + G: FnMut(O1) -> Result, +{ + move |input| { + let (input, o1) = parser.parse(input)?; + match f(o1) { + Ok(o2) => Ok((input, o2)), + Err(_) => Err(nom::Err::Failure(E::from_message(input, message))), + } + } +} + +/// Applies a function returning a [`ParseResult`] over the result of the `parser`. +/// If the parser returns an error, the result will be mapped to a recoverable +/// [`nom::Err::Error`] with the specified `message` for additional context. 
+pub(crate) fn map_error<'a, O1, O2, E: ParseError<'a>, E2, F, G>( + message: &'static str, + mut parser: F, + mut f: G, +) -> impl FnMut(&'a str) -> ParseResult<&'a str, O2, E> +where + F: Parser<&'a str, O1, E>, + G: FnMut(O1) -> Result, +{ + move |input| { + let (input, o1) = parser.parse(input)?; + match f(o1) { + Ok(o2) => Ok((input, o2)), + Err(_) => Err(nom::Err::Error(E::from_message(input, message))), + } + } +} + +/// Transforms a [`nom::Err::Error`] to a [`nom::Err::Failure`] using `message` for additional +/// context. +pub(crate) fn expect<'a, E: ParseError<'a>, F, O>( + message: &'static str, + mut f: F, +) -> impl FnMut(&'a str) -> ParseResult<&'a str, O, E> +where + F: Parser<&'a str, O, E>, +{ + move |i| match f.parse(i) { + Ok(o) => Ok(o), + Err(nom::Err::Incomplete(i)) => Err(nom::Err::Incomplete(i)), + Err(nom::Err::Error(_)) => Err(nom::Err::Failure(E::from_message(i, message))), + Err(nom::Err::Failure(e)) => Err(nom::Err::Failure(e)), + } +} + +/// Returns the result of `f` if it satisfies `is_valid`; otherwise, +/// returns an error using the specified `message`. +pub(crate) fn verify<'a, O1, O2, E: ParseError<'a>, F, G>( + message: &'static str, + mut f: F, + is_valid: G, +) -> impl FnMut(&'a str) -> ParseResult<&'a str, O1, E> +where + F: Parser<&'a str, O1, E>, + G: Fn(&O2) -> bool, + O1: Borrow, + O2: ?Sized, +{ + move |i: &str| { + let (remain, o) = f.parse(i)?; + + if is_valid(o.borrow()) { + Ok((remain, o)) + } else { + Err(nom::Err::Failure(E::from_message(i, message))) + } + } +} + +impl NomParseError for Error { + fn from_error_kind(input: I, kind: NomErrorKind) -> Self { + Self::Nom(input, kind) + } + + fn append(_: I, _: NomErrorKind, other: Self) -> Self { + other + } +} + +/// ParseResult is a type alias for [`nom::IResult`] used by nom combinator +/// functions for parsing InfluxQL. 
+pub(crate) type ParseResult> = nom::IResult; diff --git a/influxdb_influxql_parser/src/keywords.rs b/influxdb_influxql_parser/src/keywords.rs new file mode 100644 index 0000000..d665245 --- /dev/null +++ b/influxdb_influxql_parser/src/keywords.rs @@ -0,0 +1,353 @@ +//! # Parse InfluxQL [keywords] +//! +//! [keywords]: https://docs.influxdata.com/influxdb/v1.8/query_language/spec/#keywords + +use crate::internal::ParseResult; +use nom::bytes::complete::tag_no_case; +use nom::character::complete::alpha1; +use nom::combinator::{fail, verify}; +use nom::sequence::terminated; +use nom::FindToken; +use once_cell::sync::Lazy; +use std::collections::HashSet; +use std::hash::{Hash, Hasher}; + +/// Verifies the next character of `i` is valid following a keyword. +/// +/// Keywords may be followed by whitespace, statement terminator (;), parens, +/// or conditional and arithmetic operators or EOF +fn keyword_follow_char(i: &str) -> ParseResult<&str, ()> { + if i.is_empty() || b" \n\t;(),=!><+-/*|&^%".find_token(i.bytes().next().unwrap()) { + Ok((i, ())) + } else { + fail(i) + } +} + +/// Token represents a string with case-insensitive ordering and equality. +#[derive(Debug, Clone)] +pub(crate) struct Token<'a>(pub(crate) &'a str); + +impl PartialEq for Token<'_> { + fn eq(&self, other: &Self) -> bool { + self.0.len() == other.0.len() + && self + .0 + .chars() + .zip(other.0.chars()) + .all(|(l, r)| l.to_ascii_uppercase() == r.to_ascii_uppercase()) + } +} + +impl<'a> Eq for Token<'a> {} + +/// The Hash implementation for Token ensures +/// that two tokens, regardless of case, hash to the same +/// value. 
+impl<'a> Hash for Token<'a> { + fn hash(&self, state: &mut H) { + self.0 + .as_bytes() + .iter() + .map(u8::to_ascii_uppercase) + .for_each(|v| state.write_u8(v)); + } +} + +static KEYWORDS: Lazy>> = Lazy::new(|| { + HashSet::from([ + Token("ALL"), + Token("ALTER"), + Token("ANALYZE"), + Token("AND"), + Token("ANY"), + Token("AS"), + Token("ASC"), + Token("BEGIN"), + Token("BY"), + Token("CARDINALITY"), + Token("CREATE"), + Token("CONTINUOUS"), + Token("DATABASE"), + Token("DATABASES"), + Token("DEFAULT"), + Token("DELETE"), + Token("DESC"), + Token("DESTINATIONS"), + Token("DIAGNOSTICS"), + Token("DISTINCT"), + Token("DROP"), + Token("DURATION"), + Token("END"), + Token("EVERY"), + Token("EXACT"), + Token("EXPLAIN"), + Token("FIELD"), + Token("FOR"), + Token("FROM"), + Token("GRANT"), + Token("GRANTS"), + Token("GROUP"), + Token("GROUPS"), + Token("IN"), + Token("INF"), + Token("INSERT"), + Token("INTO"), + Token("KEY"), + Token("KEYS"), + Token("KILL"), + Token("LIMIT"), + Token("MEASUREMENT"), + Token("MEASUREMENTS"), + Token("NAME"), + Token("OFFSET"), + Token("OR"), + Token("ON"), + Token("ORDER"), + Token("PASSWORD"), + Token("POLICY"), + Token("POLICIES"), + Token("PRIVILEGES"), + Token("QUERIES"), + Token("QUERY"), + Token("READ"), + Token("REPLICATION"), + Token("RESAMPLE"), + Token("RETENTION"), + Token("REVOKE"), + Token("SELECT"), + Token("SERIES"), + Token("SET"), + Token("SHOW"), + Token("SHARD"), + Token("SHARDS"), + Token("SLIMIT"), + Token("SOFFSET"), + Token("STATS"), + Token("SUBSCRIPTION"), + Token("SUBSCRIPTIONS"), + Token("TAG"), + Token("TO"), + Token("USER"), + Token("USERS"), + Token("VALUES"), + Token("WHERE"), + Token("WITH"), + Token("WRITE"), + ]) +}); + +/// Matches any InfluxQL reserved keyword. 
+pub(crate) fn sql_keyword(i: &str) -> ParseResult<&str, &str> { + verify(terminated(alpha1, keyword_follow_char), |tok: &str| { + KEYWORDS.contains(&Token(tok)) + })(i) +} + +/// Recognizes a case-insensitive `keyword`, ensuring it is followed by +/// a valid separator. +pub(crate) fn keyword<'a>(keyword: &'static str) -> impl FnMut(&'a str) -> ParseResult<&str, &str> { + terminated(tag_no_case(keyword), keyword_follow_char) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::assert_error; + use assert_matches::assert_matches; + + #[test] + fn test_keywords() { + // all keywords + + sql_keyword("ALL").unwrap(); + sql_keyword("ALTER").unwrap(); + sql_keyword("ANALYZE").unwrap(); + sql_keyword("ANY").unwrap(); + sql_keyword("AS").unwrap(); + sql_keyword("ASC").unwrap(); + sql_keyword("BEGIN").unwrap(); + sql_keyword("BY").unwrap(); + sql_keyword("CARDINALITY").unwrap(); + sql_keyword("CREATE").unwrap(); + sql_keyword("CONTINUOUS").unwrap(); + sql_keyword("DATABASE").unwrap(); + sql_keyword("DATABASES").unwrap(); + sql_keyword("DEFAULT").unwrap(); + sql_keyword("DELETE").unwrap(); + sql_keyword("DESC").unwrap(); + sql_keyword("DESTINATIONS").unwrap(); + sql_keyword("DIAGNOSTICS").unwrap(); + sql_keyword("DISTINCT").unwrap(); + sql_keyword("DROP").unwrap(); + sql_keyword("DURATION").unwrap(); + sql_keyword("END").unwrap(); + sql_keyword("EVERY").unwrap(); + sql_keyword("EXACT").unwrap(); + sql_keyword("EXPLAIN").unwrap(); + sql_keyword("FIELD").unwrap(); + sql_keyword("FOR").unwrap(); + sql_keyword("FROM").unwrap(); + sql_keyword("GRANT").unwrap(); + sql_keyword("GRANTS").unwrap(); + sql_keyword("GROUP").unwrap(); + sql_keyword("GROUPS").unwrap(); + sql_keyword("IN").unwrap(); + sql_keyword("INF").unwrap(); + sql_keyword("INSERT").unwrap(); + sql_keyword("INTO").unwrap(); + sql_keyword("KEY").unwrap(); + sql_keyword("KEYS").unwrap(); + sql_keyword("KILL").unwrap(); + sql_keyword("LIMIT").unwrap(); + sql_keyword("MEASUREMENT").unwrap(); + 
sql_keyword("MEASUREMENTS").unwrap(); + sql_keyword("NAME").unwrap(); + sql_keyword("OFFSET").unwrap(); + sql_keyword("ON").unwrap(); + sql_keyword("ORDER").unwrap(); + sql_keyword("PASSWORD").unwrap(); + sql_keyword("POLICY").unwrap(); + sql_keyword("POLICIES").unwrap(); + sql_keyword("PRIVILEGES").unwrap(); + sql_keyword("QUERIES").unwrap(); + sql_keyword("QUERY").unwrap(); + sql_keyword("READ").unwrap(); + sql_keyword("REPLICATION").unwrap(); + sql_keyword("RESAMPLE").unwrap(); + sql_keyword("RETENTION").unwrap(); + sql_keyword("REVOKE").unwrap(); + sql_keyword("SELECT").unwrap(); + sql_keyword("SERIES").unwrap(); + sql_keyword("SET").unwrap(); + sql_keyword("SHOW").unwrap(); + sql_keyword("SHARD").unwrap(); + sql_keyword("SHARDS").unwrap(); + sql_keyword("SLIMIT").unwrap(); + sql_keyword("SOFFSET").unwrap(); + sql_keyword("STATS").unwrap(); + sql_keyword("SUBSCRIPTION").unwrap(); + sql_keyword("SUBSCRIPTIONS").unwrap(); + sql_keyword("TAG").unwrap(); + sql_keyword("TO").unwrap(); + sql_keyword("USER").unwrap(); + sql_keyword("USERS").unwrap(); + sql_keyword("VALUES").unwrap(); + sql_keyword("WHERE").unwrap(); + sql_keyword("WITH").unwrap(); + sql_keyword("WRITE").unwrap(); + + // case insensitivity + sql_keyword("all").unwrap(); + + // ┌─────────────────────────────┐ + // │ Fallible tests │ + // └─────────────────────────────┘ + + sql_keyword("NOT_A_KEYWORD").unwrap_err(); + } + + #[test] + fn test_keyword() { + // Create a parser for the OR keyword + let mut or_keyword = keyword("OR"); + + // Can parse with matching case + let (rem, got) = or_keyword("OR").unwrap(); + assert_eq!(rem, ""); + assert_eq!(got, "OR"); + + // Not case sensitive + let (rem, got) = or_keyword("or").unwrap(); + assert_eq!(rem, ""); + assert_eq!(got, "or"); + + // Does not consume input that follows a keyword + let (rem, got) = or_keyword("or(a AND b)").unwrap(); + assert_eq!(rem, "(a AND b)"); + assert_eq!(got, "or"); + + // Will fail because keyword `OR` in `ORDER` is not recognized, 
as is not terminated by a valid character + let err = or_keyword("ORDER").unwrap_err(); + assert_matches!(err, nom::Err::Error(crate::internal::Error::Nom(_, kind)) if kind == nom::error::ErrorKind::Fail); + } + + #[test] + fn test_keyword_followed_by_valid_char() { + let mut tag_keyword = keyword("TAG"); + + // followed by EOF + let (rem, got) = tag_keyword("tag").unwrap(); + assert_eq!(rem, ""); + assert_eq!(got, "tag"); + + // + // Test some of the expected characters + // + + let (rem, got) = tag_keyword("tag!=foo").unwrap(); + assert_eq!(rem, "!=foo"); + assert_eq!(got, "tag"); + + let (rem, got) = tag_keyword("tag>foo").unwrap(); + assert_eq!(rem, ">foo"); + assert_eq!(got, "tag"); + + let (rem, got) = tag_keyword("tag&1 = foo").unwrap(); + assert_eq!(rem, "&1 = foo"); + assert_eq!(got, "tag"); + + // Fallible + + assert_error!(tag_keyword("tag$"), Fail); + } + + #[test] + fn test_token() { + // Are equal with differing case + let (a, b) = (Token("and"), Token("AND")); + assert_eq!(a, b); + + // Are equal with same case + let (a, b) = (Token("and"), Token("and")); + assert_eq!(a, b); + + // a < b + let (a, b) = (Token("and"), Token("apple")); + assert_ne!(a, b); + + // a < b + let (a, b) = (Token("and"), Token("APPLE")); + assert_ne!(a, b); + + // a < b + let (a, b) = (Token("AND"), Token("apple")); + assert_ne!(a, b); + + // a > b + let (a, b) = (Token("and"), Token("aardvark")); + assert_ne!(a, b); + + // a > b + let (a, b) = (Token("and"), Token("AARDVARK")); + assert_ne!(a, b); + + // a > b + let (a, b) = (Token("AND"), Token("aardvark")); + assert_ne!(a, b); + + // Validate prefixes don't match and are correct ordering + + let (a, b) = (Token("aaa"), Token("aaabbb")); + assert_ne!(a, b); + + let (a, b) = (Token("aaabbb"), Token("aaa")); + assert_ne!(a, b); + + let (a, b) = (Token("aaa"), Token("AAABBB")); + assert_ne!(a, b); + + let (a, b) = (Token("AAABBB"), Token("aaa")); + assert_ne!(a, b); + } +} diff --git a/influxdb_influxql_parser/src/lib.rs 
b/influxdb_influxql_parser/src/lib.rs new file mode 100644 index 0000000..4bb1a60 --- /dev/null +++ b/influxdb_influxql_parser/src/lib.rs @@ -0,0 +1,188 @@ +//! # Parse a subset of [InfluxQL] +//! +//! [InfluxQL]: https://docs.influxdata.com/influxdb/v1.8/query_language + +#![deny(rustdoc::broken_intra_doc_links, rust_2018_idioms)] +#![warn( + missing_copy_implementations, + missing_docs, + clippy::explicit_iter_loop, + // See https://github.com/influxdata/influxdb_iox/pull/1671 + clippy::future_not_send, + clippy::use_self, + clippy::clone_on_ref_ptr, + clippy::todo, + clippy::dbg_macro, + unused_crate_dependencies +)] + +// Workaround for "unused crate" lint false positives. +use workspace_hack as _; + +use crate::common::{statement_terminator, ws0}; +use crate::internal::Error as InternalError; +use crate::statement::{statement, Statement}; +use common::ParseError; +use nom::combinator::eof; +use nom::Offset; + +#[cfg(test)] +mod test_util; + +pub mod common; +pub mod create; +pub mod delete; +pub mod drop; +pub mod explain; +pub mod expression; +pub mod functions; +pub mod identifier; +mod internal; +mod keywords; +pub mod literal; +pub mod parameter; +pub mod select; +pub mod show; +pub mod show_field_keys; +pub mod show_measurements; +pub mod show_retention_policies; +pub mod show_tag_keys; +pub mod show_tag_values; +pub mod simple_from_clause; +pub mod statement; +pub mod string; +pub mod time_range; +pub mod timestamp; +pub mod visit; +pub mod visit_mut; + +/// ParseResult is type that represents the success or failure of parsing +/// a given input into a set of InfluxQL statements. +/// +/// Errors are human-readable messages indicating the cause of the parse failure. +pub type ParseResult = Result, ParseError>; + +/// Parse the input into a set of InfluxQL statements. 
+pub fn parse_statements(input: &str) -> ParseResult { + let mut res = Vec::new(); + let mut i: &str = input; + + loop { + // Consume whitespace from the input + (i, _) = ws0(i).expect("ws0 is infallible"); + + if eof::<_, nom::error::Error<_>>(i).is_ok() { + return Ok(res); + } + + if let Ok((i1, _)) = statement_terminator(i) { + i = i1; + continue; + } + + match statement(i) { + Ok((i1, o)) => { + res.push(o); + i = i1; + } + Err(nom::Err::Failure(InternalError::Syntax { + input: pos, + message, + })) => { + return Err(ParseError { + message: message.into(), + pos: input.offset(pos), + }) + } + // any other error indicates an invalid statement + Err(_) => { + return Err(ParseError { + message: "invalid SQL statement".into(), + pos: input.offset(i), + }) + } + } + } +} + +#[cfg(test)] +mod test { + use crate::parse_statements; + + /// Validates that the [`parse_statements`] function + /// handles statement terminators and errors. + #[test] + fn test_parse_statements() { + // Parse a single statement, without a terminator + let got = parse_statements("SHOW MEASUREMENTS").unwrap(); + assert_eq!(got.first().unwrap().to_string(), "SHOW MEASUREMENTS"); + + // Parse a single statement, with a terminator + let got = parse_statements("SHOW MEASUREMENTS;").unwrap(); + assert_eq!(got[0].to_string(), "SHOW MEASUREMENTS"); + + // Parse multiple statements with whitespace + let got = parse_statements("SHOW MEASUREMENTS;\nSHOW MEASUREMENTS LIMIT 1").unwrap(); + assert_eq!(got[0].to_string(), "SHOW MEASUREMENTS"); + assert_eq!(got[1].to_string(), "SHOW MEASUREMENTS LIMIT 1"); + + // Parse multiple statements with a terminator in quotes, ensuring it is not interpreted as + // a terminator + let got = + parse_statements("SHOW MEASUREMENTS WITH MEASUREMENT = \";\";SHOW DATABASES").unwrap(); + assert_eq!( + got[0].to_string(), + "SHOW MEASUREMENTS WITH MEASUREMENT = \";\"" + ); + assert_eq!(got[1].to_string(), "SHOW DATABASES"); + + // Parses a statement with a comment + let got = 
parse_statements( + "SELECT idle FROM cpu WHERE host = 'host1' --GROUP BY host fill(null)", + ) + .unwrap(); + assert_eq!( + got[0].to_string(), + "SELECT idle FROM cpu WHERE host = 'host1'" + ); + + // Parses multiple statements with a comment + let got = parse_statements( + "SELECT idle FROM cpu WHERE host = 'host1' --GROUP BY host fill(null)\nSHOW DATABASES", + ) + .unwrap(); + assert_eq!( + got[0].to_string(), + "SELECT idle FROM cpu WHERE host = 'host1'" + ); + assert_eq!(got[1].to_string(), "SHOW DATABASES"); + + // Parses statement with inline comment + let got = parse_statements(r#"SELECT idle FROM cpu WHERE/* time > now() AND */host = 'host1' --GROUP BY host fill(null)"#).unwrap(); + assert_eq!( + got[0].to_string(), + "SELECT idle FROM cpu WHERE host = 'host1'" + ); + + // Parses empty single-line comments in various placements + let got = parse_statements( + r#"-- foo + -- + -- + SELECT value FROM cpu-- + -- foo + ;SELECT val2 FROM cpu"#, + ) + .unwrap(); + assert_eq!(got[0].to_string(), "SELECT value FROM cpu"); + assert_eq!(got[1].to_string(), "SELECT val2 FROM cpu"); + + // Returns error for invalid statement + let got = parse_statements("BAD SQL").unwrap_err(); + assert_eq!(got.to_string(), "invalid SQL statement at pos 0"); + + // Returns error for invalid statement after first + let got = parse_statements("SHOW MEASUREMENTS;BAD SQL").unwrap_err(); + assert_eq!(got.to_string(), "invalid SQL statement at pos 18"); + } +} diff --git a/influxdb_influxql_parser/src/literal.rs b/influxdb_influxql_parser/src/literal.rs new file mode 100644 index 0000000..3611987 --- /dev/null +++ b/influxdb_influxql_parser/src/literal.rs @@ -0,0 +1,600 @@ +//! Types and parsers for literals. 
+ +use crate::common::ws0; +use crate::internal::{map_error, map_fail, ParseResult}; +use crate::keywords::keyword; +use crate::string::{regex, single_quoted_string, Regex}; +use crate::timestamp::Timestamp; +use crate::{impl_tuple_clause, write_escaped}; +use chrono::{NaiveDateTime, Offset}; +use nom::branch::alt; +use nom::bytes::complete::tag; +use nom::character::complete::{char, digit0, digit1}; +use nom::combinator::{map, opt, recognize, value}; +use nom::multi::fold_many1; +use nom::sequence::{pair, preceded, separated_pair}; +use std::fmt; +use std::fmt::{Display, Formatter, Write}; + +/// Number of nanoseconds in a microsecond. +const NANOS_PER_MICRO: i64 = 1000; +/// Number of nanoseconds in a millisecond. +const NANOS_PER_MILLI: i64 = 1000 * NANOS_PER_MICRO; +/// Number of nanoseconds in a second. +const NANOS_PER_SEC: i64 = 1000 * NANOS_PER_MILLI; +/// Number of nanoseconds in a minute. +const NANOS_PER_MIN: i64 = 60 * NANOS_PER_SEC; +/// Number of nanoseconds in an hour. +const NANOS_PER_HOUR: i64 = 60 * NANOS_PER_MIN; +/// Number of nanoseconds in a day. +const NANOS_PER_DAY: i64 = 24 * NANOS_PER_HOUR; +/// Number of nanoseconds in a week. +const NANOS_PER_WEEK: i64 = 7 * NANOS_PER_DAY; + +/// Primitive InfluxQL literal values, such as strings and regular expressions. +#[derive(Clone, Debug, PartialEq)] +pub enum Literal { + /// Signed integer literal. + Integer(i64), + + /// Unsigned integer literal. + Unsigned(u64), + + /// Float literal. + Float(f64), + + /// Unescaped string literal. + String(String), + + /// Boolean literal. + Boolean(bool), + + /// Duration literal in nanoseconds. + Duration(Duration), + + /// Unescaped regular expression literal. + Regex(Regex), + + /// A timestamp identified in a time range expression of a conditional expression. 
+ Timestamp(Timestamp), +} + +impl From for Literal { + fn from(v: String) -> Self { + Self::String(v) + } +} + +impl From for Literal { + fn from(v: u64) -> Self { + Self::Unsigned(v) + } +} + +impl From for Literal { + fn from(v: i64) -> Self { + Self::Integer(v) + } +} + +impl From for Literal { + fn from(v: f64) -> Self { + Self::Float(v) + } +} + +impl From for Literal { + fn from(v: bool) -> Self { + Self::Boolean(v) + } +} + +impl From for Literal { + fn from(v: Duration) -> Self { + Self::Duration(v) + } +} + +impl From for Literal { + fn from(v: Regex) -> Self { + Self::Regex(v) + } +} + +impl Display for Literal { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Integer(v) => write!(f, "{v}"), + Self::Unsigned(v) => write!(f, "{v}"), + Self::Float(v) => write!(f, "{v}"), + Self::String(v) => { + f.write_char('\'')?; + write_escaped!(f, v, '\n' => "\\n", '\\' => "\\\\", '\'' => "\\'", '"' => "\\\""); + f.write_char('\'') + } + Self::Boolean(v) => write!(f, "{}", if *v { "true" } else { "false" }), + Self::Duration(v) => write!(f, "{v}"), + Self::Regex(v) => write!(f, "{v}"), + Self::Timestamp(ts) => write!(f, "{}", ts.to_rfc3339()), + } + } +} + +/// Parse an InfluxQL integer. +/// +/// InfluxQL defines an integer as follows +/// +/// ```text +/// INTEGER ::= [0-9]+ +/// ``` +fn integer(i: &str) -> ParseResult<&str, i64> { + map_error("unable to parse integer", digit1, &str::parse)(i) +} + +/// Parse an InfluxQL integer to a [`Literal::Integer`] or [`Literal::Unsigned`] +/// if the string overflows. This behavior is consistent with [InfluxQL]. 
+/// +/// InfluxQL defines an integer as follows +/// +/// ```text +/// INTEGER ::= [0-9]+ +/// ``` +/// +/// [InfluxQL]: https://github.com/influxdata/influxql/blob/7e7d61973256ffeef4b99edd0a89f18a9e52fa2d/parser.go#L2669-L2675 +fn integer_literal(i: &str) -> ParseResult<&str, Literal> { + map_fail( + "unable to parse integer due to overflow", + digit1, + |s: &str| { + s.parse::() + .map(Literal::Integer) + .or_else(|_| s.parse::().map(Literal::Unsigned)) + }, + )(i) +} + +/// Parse an unsigned InfluxQL integer. +/// +/// InfluxQL defines an integer as follows +/// +/// ```text +/// INTEGER ::= [0-9]+ +/// ``` +pub(crate) fn unsigned_integer(i: &str) -> ParseResult<&str, u64> { + map_fail("unable to parse unsigned integer", digit1, &str::parse)(i) +} + +/// Parse an unsigned InfluxQL floating point number. +/// +/// InfluxQL defines a floating point number as follows +/// +/// ```text +/// float ::= INTEGER "." INTEGER +/// INTEGER ::= [0-9]+ +/// ``` +fn float(i: &str) -> ParseResult<&str, f64> { + map_fail( + "unable to parse float", + recognize(separated_pair(digit0, tag("."), digit1)), + &str::parse, + )(i) +} + +/// Represents any signed number. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum Number { + /// Contains a 64-bit integer. + Integer(i64), + /// Contains a 64-bit float. + Float(f64), +} + +impl Display for Number { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Integer(v) => fmt::Display::fmt(v, f), + Self::Float(v) => fmt::Display::fmt(v, f), + } + } +} + +impl From for Number { + fn from(v: f64) -> Self { + Self::Float(v) + } +} + +impl From for Number { + fn from(v: i64) -> Self { + Self::Integer(v) + } +} + +/// Parse a signed [`Number`]. 
+pub(crate) fn number(i: &str) -> ParseResult<&str, Number> { + let (remaining, sign) = opt(alt((char('-'), char('+'))))(i)?; + preceded( + ws0, + alt(( + map(float, move |v| { + Number::Float(v * if let Some('-') = sign { -1.0 } else { 1.0 }) + }), + map(integer, move |v| { + Number::Integer(v * if let Some('-') = sign { -1 } else { 1 }) + }), + )), + )(remaining) +} + +/// Parse the input for an InfluxQL boolean, which must be the value `true` or `false`. +fn boolean(i: &str) -> ParseResult<&str, bool> { + alt((value(true, keyword("TRUE")), value(false, keyword("FALSE"))))(i) +} + +#[derive(Clone)] +enum DurationUnit { + Nanosecond, + Microsecond, + Millisecond, + Second, + Minute, + Hour, + Day, + Week, +} + +/// Represents an InfluxQL duration in nanoseconds. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Duration(pub(crate) i64); + +impl_tuple_clause!(Duration, i64); + +static DIVISORS: [(i64, &str); 8] = [ + (NANOS_PER_WEEK, "w"), + (NANOS_PER_DAY, "d"), + (NANOS_PER_HOUR, "h"), + (NANOS_PER_MIN, "m"), + (NANOS_PER_SEC, "s"), + (NANOS_PER_MILLI, "ms"), + (NANOS_PER_MICRO, "us"), + (1, "ns"), +]; + +impl Display for Duration { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let v = if self.0.is_negative() { + write!(f, "-")?; + -self.0 + } else { + self.0 + }; + match v { + 0 => f.write_str("0s")?, + mut i => { + // only return the divisors that are > self + for (div, unit) in DIVISORS.iter().filter(|(div, _)| v > *div) { + let units = i / div; + if units > 0 { + write!(f, "{units}{unit}")?; + i -= units * div; + } + } + } + } + + Ok(()) + } +} + +/// Parse the input for a InfluxQL duration fragment and returns the value in nanoseconds. 
+fn single_duration(i: &str) -> ParseResult<&str, i64> { + use DurationUnit::*; + + map_fail( + "overflow", + pair( + integer, + alt(( + value(Nanosecond, tag("ns")), // nanoseconds + value(Microsecond, tag("µ")), // microseconds + value(Microsecond, tag("u")), // microseconds + value(Millisecond, tag("ms")), // milliseconds + value(Second, tag("s")), // seconds + value(Minute, tag("m")), // minutes + value(Hour, tag("h")), // hours + value(Day, tag("d")), // days + value(Week, tag("w")), // weeks + )), + ), + |(v, unit)| { + (match unit { + Nanosecond => Some(v), + Microsecond => v.checked_mul(NANOS_PER_MICRO), + Millisecond => v.checked_mul(NANOS_PER_MILLI), + Second => v.checked_mul(NANOS_PER_SEC), + Minute => v.checked_mul(NANOS_PER_MIN), + Hour => v.checked_mul(NANOS_PER_HOUR), + Day => v.checked_mul(NANOS_PER_DAY), + Week => v.checked_mul(NANOS_PER_WEEK), + }) + .ok_or("integer overflow") + }, + )(i) +} + +/// Parse the input for an InfluxQL duration. +pub(crate) fn duration(i: &str) -> ParseResult<&str, Duration> { + map( + fold_many1(single_duration, || 0, |acc, fragment| acc + fragment), + Duration, + )(i) +} + +/// Parse an InfluxQL literal, except a [`Regex`]. +/// +/// Use [`literal`] for parsing any literals, excluding regular expressions. +pub(crate) fn literal_no_regex(i: &str) -> ParseResult<&str, Literal> { + alt(( + // NOTE: order is important, as floats should be tested before durations and integers. + map(float, Literal::Float), + map(duration, Literal::Duration), + integer_literal, + map(single_quoted_string, Literal::String), + map(boolean, Literal::Boolean), + ))(i) +} + +/// Parse any InfluxQL literal. +pub(crate) fn literal(i: &str) -> ParseResult<&str, Literal> { + alt((literal_no_regex, map(regex, Literal::Regex)))(i) +} + +/// Parse an InfluxQL literal regular expression. +pub(crate) fn literal_regex(i: &str) -> ParseResult<&str, Literal> { + map(regex, Literal::Regex)(i) +} + +/// Returns `nanos` as a timestamp. 
+pub fn nanos_to_timestamp(nanos: i64) -> Timestamp { + let (secs, nsec) = num_integer::div_mod_floor(nanos, NANOS_PER_SEC); + + Timestamp::from_naive_utc_and_offset( + NaiveDateTime::from_timestamp_opt(secs, nsec as u32) + .expect("unable to convert duration to timestamp"), + chrono::Utc.fix(), + ) +} + +#[cfg(test)] +mod test { + use super::*; + use assert_matches::assert_matches; + + #[test] + fn test_literal_no_regex() { + // Whole numbers are parsed first as a signed integer, and if that overflows, + // tries an unsigned integer, which is consistent with InfluxQL + let (_, got) = literal_no_regex("42").unwrap(); + assert_matches!(got, Literal::Integer(42)); + + // > i64::MAX + 1 should be parsed as an unsigned integer + let (_, got) = literal_no_regex("9223372036854775808").unwrap(); + assert_matches!(got, Literal::Unsigned(9223372036854775808)); + + let (_, got) = literal_no_regex("42.69").unwrap(); + assert_matches!(got, Literal::Float(v) if v == 42.69); + + let (_, got) = literal_no_regex("'quick draw'").unwrap(); + assert_matches!(got, Literal::String(v) if v == "quick draw"); + + let (_, got) = literal_no_regex("false").unwrap(); + assert_matches!(got, Literal::Boolean(false)); + + let (_, got) = literal_no_regex("true").unwrap(); + assert_matches!(got, Literal::Boolean(true)); + + let (_, got) = literal_no_regex("3h25m").unwrap(); + assert_matches!(got, Literal::Duration(v) if v == Duration(3 * NANOS_PER_HOUR + 25 * NANOS_PER_MIN)); + + // Fallible cases + literal_no_regex("/foo/").unwrap_err(); + } + + #[test] + fn test_literal() { + let (_, got) = literal("/^(match|this)$/").unwrap(); + assert_matches!(got, Literal::Regex(v) if v == "^(match|this)$".into()); + } + + #[test] + fn test_literal_regex() { + let (_, got) = literal_regex("/^(match|this)$/").unwrap(); + assert_matches!(got, Literal::Regex(v) if v == "^(match|this)$".into()); + } + + #[test] + fn test_integer() { + let (_, got) = integer("42").unwrap(); + assert_eq!(got, 42); + + let (_, got) 
= integer(&i64::MAX.to_string()[..]).unwrap(); + assert_eq!(got, i64::MAX); + + // Fallible cases + + integer("hello").unwrap_err(); + + integer("9223372036854775808").expect_err("expected overflow"); + } + + #[test] + fn test_unsigned_integer() { + let (_, got) = unsigned_integer("42").unwrap(); + assert_eq!(got, 42); + + let (_, got) = unsigned_integer(&u64::MAX.to_string()[..]).unwrap(); + assert_eq!(got, u64::MAX); + + // Fallible cases + + unsigned_integer("hello").unwrap_err(); + } + + #[test] + fn test_float() { + let (_, got) = float("42.69").unwrap(); + assert_eq!(got, 42.69); + + let (_, got) = float(".25").unwrap(); + assert_eq!(got, 0.25); + + let (_, got) = float(&format!("{:.1}", f64::MAX)[..]).unwrap(); + assert_eq!(got, f64::MAX); + + // Fallible cases + + // missing trailing digits + float("41.").unwrap_err(); + + // missing decimal + float("41").unwrap_err(); + } + + #[test] + fn test_boolean() { + let (_, got) = boolean("true").unwrap(); + assert!(got); + let (_, got) = boolean("false").unwrap(); + assert!(!got); + + // Fallible cases + + boolean("truey").unwrap_err(); + boolean("falsey").unwrap_err(); + } + + #[test] + fn test_duration_fragment() { + let (_, got) = single_duration("38ns").unwrap(); + assert_eq!(got, 38); + + let (_, got) = single_duration("22u").unwrap(); + assert_eq!(got, 22 * NANOS_PER_MICRO); + + let (rem, got) = single_duration("22us").unwrap(); + assert_eq!(got, 22 * NANOS_PER_MICRO); + assert_eq!(rem, "s"); // prove that we ignore the trailing s + + let (_, got) = single_duration("7µ").unwrap(); + assert_eq!(got, 7 * NANOS_PER_MICRO); + + let (_, got) = single_duration("15ms").unwrap(); + assert_eq!(got, 15 * NANOS_PER_MILLI); + + let (_, got) = single_duration("53s").unwrap(); + assert_eq!(got, 53 * NANOS_PER_SEC); + + let (_, got) = single_duration("158m").unwrap(); + assert_eq!(got, 158 * NANOS_PER_MIN); + + let (_, got) = single_duration("39h").unwrap(); + assert_eq!(got, 39 * NANOS_PER_HOUR); + + let (_, got) = 
single_duration("2d").unwrap(); + assert_eq!(got, 2 * NANOS_PER_DAY); + + let (_, got) = single_duration("5w").unwrap(); + assert_eq!(got, 5 * NANOS_PER_WEEK); + + // Fallible + + // Handle overflow + single_duration("16000w").expect_err("expected overflow"); + } + + #[test] + fn test_duration() { + let (_, got) = duration("10h3m2s").unwrap(); + assert_eq!( + got, + Duration(10 * NANOS_PER_HOUR + 3 * NANOS_PER_MIN + 2 * NANOS_PER_SEC) + ); + } + + #[test] + fn test_display_duration() { + let (_, d) = duration("3w2h15ms").unwrap(); + assert_eq!(d.to_string(), "3w2h15ms"); + + let (_, d) = duration("5s5s5s5s5s").unwrap(); + assert_eq!(d.to_string(), "25s"); + + let d = Duration(0); + assert_eq!(d.to_string(), "0s"); + + // Negative duration + let (_, d) = duration("3w2h15ms").unwrap(); + let d = Duration(-d.0); + assert_eq!(d.to_string(), "-3w2h15ms"); + + let d = Duration( + 20 * NANOS_PER_WEEK + + 6 * NANOS_PER_DAY + + 13 * NANOS_PER_HOUR + + 11 * NANOS_PER_MIN + + 10 * NANOS_PER_SEC + + 9 * NANOS_PER_MILLI + + 8 * NANOS_PER_MICRO + + 500, + ); + assert_eq!(d.to_string(), "20w6d13h11m10s9ms8us500ns"); + } + + #[test] + fn test_number() { + // Test floating point numbers + let (_, got) = number("55.3").unwrap(); + assert_matches!(got, Number::Float(v) if v == 55.3); + + let (_, got) = number("-18.9").unwrap(); + assert_matches!(got, Number::Float(v) if v == -18.9); + + let (_, got) = number("- 18.9").unwrap(); + assert_matches!(got, Number::Float(v) if v == -18.9); + + let (_, got) = number("+33.1").unwrap(); + assert_matches!(got, Number::Float(v) if v == 33.1); + + let (_, got) = number("+ 33.1").unwrap(); + assert_matches!(got, Number::Float(v) if v == 33.1); + + // Test integers + let (_, got) = number("42").unwrap(); + assert_matches!(got, Number::Integer(v) if v == 42); + + let (_, got) = number("-32").unwrap(); + assert_matches!(got, Number::Integer(v) if v == -32); + + let (_, got) = number("- 32").unwrap(); + assert_matches!(got, Number::Integer(v) if v == 
-32); + + let (_, got) = number("+501").unwrap(); + assert_matches!(got, Number::Integer(v) if v == 501); + + let (_, got) = number("+ 501").unwrap(); + assert_matches!(got, Number::Integer(v) if v == 501); + } + + #[test] + fn test_nanos_to_timestamp() { + let ts = nanos_to_timestamp(0); + assert_eq!(ts.to_rfc3339(), "1970-01-01T00:00:00+00:00"); + + // infallible + let ts = nanos_to_timestamp(i64::MAX); + assert_eq!(ts.timestamp_nanos_opt().unwrap(), i64::MAX); + + let ts = nanos_to_timestamp(i64::MIN); + assert_eq!(ts.timestamp_nanos_opt().unwrap(), i64::MIN); + } +} diff --git a/influxdb_influxql_parser/src/parameter.rs b/influxdb_influxql_parser/src/parameter.rs new file mode 100644 index 0000000..5ed28b7 --- /dev/null +++ b/influxdb_influxql_parser/src/parameter.rs @@ -0,0 +1,107 @@ +//! # Parse an InfluxQL [bind parameter] +//! +//! Bind parameters are parsed where a literal value may appear and are prefixed +//! by a `$`. Per the original Go [implementation], the token following the `$` is +//! parsed as an identifier, and therefore may appear in double quotes. +//! +//! [bind parameter]: https://docs.influxdata.com/influxdb/v1.8/tools/api/#bind-parameters +//! [implementation]: https://github.com/influxdata/influxql/blob/df51a45762be9c1b578f01718fa92d286a843fe9/scanner.go#L57-L62 + +use crate::internal::ParseResult; +use crate::string::double_quoted_string; +use crate::{impl_tuple_clause, write_quoted_string}; +use nom::branch::alt; +use nom::bytes::complete::tag; +use nom::character::complete::{alphanumeric1, char}; +use nom::combinator::{map, recognize}; +use nom::multi::many1_count; +use nom::sequence::preceded; +use std::fmt; +use std::fmt::{Display, Formatter, Write}; + +/// Parse an unquoted InfluxQL bind parameter. +fn unquoted_parameter(i: &str) -> ParseResult<&str, &str> { + recognize(many1_count(alt((alphanumeric1, tag("_")))))(i) +} + +/// A type that represents an InfluxQL bind parameter. 
+#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub struct BindParameter(pub(crate) String); + +impl_tuple_clause!(BindParameter, String); + +impl From<&str> for BindParameter { + fn from(s: &str) -> Self { + Self(s.to_string()) + } +} + +impl Display for BindParameter { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_char('$')?; + write_quoted_string!(f, '"', self.0.as_str(), unquoted_parameter, '\n' => "\\n", '\\' => "\\\\", '"' => "\\\""); + Ok(()) + } +} + +/// Parses an InfluxQL [BindParameter]. +pub(crate) fn parameter(i: &str) -> ParseResult<&str, BindParameter> { + // See: https://github.com/influxdata/influxql/blob/df51a45762be9c1b578f01718fa92d286a843fe9/scanner.go#L358-L362 + preceded( + char('$'), + alt(( + map(unquoted_parameter, Into::into), + map(double_quoted_string, Into::into), + )), + )(i) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_parameter() { + // all ascii + let (_, got) = parameter("$cpu").unwrap(); + assert_eq!(got, "cpu".into()); + + // digits + let (_, got) = parameter("$01").unwrap(); + assert_eq!(got, "01".into()); + + // all valid chars + let (_, got) = parameter("$cpu_0").unwrap(); + assert_eq!(got, "cpu_0".into()); + + // keyword + let (_, got) = parameter("$from").unwrap(); + assert_eq!(got, "from".into()); + + // quoted + let (_, got) = parameter("$\"quick draw\"").unwrap(); + assert_eq!(got, "quick draw".into()); + + // ┌─────────────────────────────┐ + // │ Fallible tests │ + // └─────────────────────────────┘ + + // missing `$` prefix + parameter("cpu").unwrap_err(); + } + + #[test] + fn test_bind_parameter_display() { + // BindParameter displays quoted output + let got = BindParameter("from foo".into()).to_string(); + assert_eq!(got, r#"$"from foo""#); + + // BindParameter displays quoted and escaped output + let got = BindParameter("from\nfoo".into()).to_string(); + assert_eq!(got, r#"$"from\nfoo""#); + + // BindParameter displays unquoted output + let got = 
BindParameter("quick_draw".into()).to_string(); + assert_eq!(got, "$quick_draw"); + } +} diff --git a/influxdb_influxql_parser/src/select.rs b/influxdb_influxql_parser/src/select.rs new file mode 100644 index 0000000..f0568c3 --- /dev/null +++ b/influxdb_influxql_parser/src/select.rs @@ -0,0 +1,1404 @@ +//! Types and parsers for the [`SELECT`][sql] statement. +//! +//! [sql]: https://docs.influxdata.com/influxdb/v1.8/query_language/explore-data/#the-basic-select-statement + +use crate::common::{ + limit_clause, offset_clause, order_by_clause, qualified_measurement_name, where_clause, ws0, + ws1, LimitClause, OffsetClause, OrderByClause, ParseError, Parser, QualifiedMeasurementName, + WhereClause, ZeroOrMore, +}; +use crate::expression::arithmetic::Expr::Wildcard; +use crate::expression::arithmetic::{ + arithmetic, call_expression, var_ref, ArithmeticParsers, Expr, WildcardType, +}; +use crate::expression::{Call, VarRef}; +use crate::functions::is_now_function; +use crate::identifier::{identifier, Identifier}; +use crate::impl_tuple_clause; +use crate::internal::{expect, map_fail, verify, ParseResult}; +use crate::keywords::keyword; +use crate::literal::{duration, literal, number, unsigned_integer, Literal, Number}; +use crate::parameter::parameter; +use crate::select::MeasurementSelection::Subquery; +use crate::string::{regex, single_quoted_string, Regex}; +use nom::branch::alt; +use nom::bytes::complete::tag; +use nom::character::complete::char; +use nom::combinator::{map, opt, value}; +use nom::sequence::{delimited, pair, preceded, tuple}; +use nom::Offset; +use std::fmt; +use std::fmt::{Display, Formatter, Write}; +use std::str::FromStr; + +/// Represents a `SELECT` statement. +#[derive(Clone, Debug, PartialEq)] +pub struct SelectStatement { + /// Expressions returned by the selection. + pub fields: FieldList, + + /// A list of measurements or subqueries used as the source data for the selection. 
+ pub from: FromMeasurementClause, + + /// A conditional expression to filter the selection. + pub condition: Option, + + /// Expressions used for grouping the selection. + pub group_by: Option, + + /// The [fill] clause specifies the fill behaviour for the selection. If the value is [`None`], + /// it is the same behavior as `fill(null)`. + /// + /// [fill]: https://docs.influxdata.com/influxdb/v1.8/query_language/explore-data/#group-by-time-intervals-and-fill + pub fill: Option, + + /// Configures the ordering of the selection by time. + pub order_by: Option, + + /// A value to restrict the number of rows returned. + pub limit: Option, + + /// A value to specify an offset to start retrieving rows. + pub offset: Option, + + /// A value to restrict the number of series returned. + pub series_limit: Option, + + /// A value to specify an offset to start retrieving series. + pub series_offset: Option, + + /// The timezone for the query, specified as [`tz('
{
        // write-through: create in the backing catalog; table snapshots are built on demand
        self.backing
            .repositories()
            .tables()
            .create(name, partition_template, namespace_id)
            .await
    }

    async fn get_by_id(&mut self, table_id: TableId) -> Result<Option<Table>> {
        self.backing
            .repositories()
            .tables()
            .get_by_id(table_id)
            .await
    }

    async fn get_by_namespace_and_name(
        &mut self,
        namespace_id: NamespaceId,
        name: &str,
    ) -> Result<Option<Table>> {
        self.backing
            .repositories()
            .tables()
            .get_by_namespace_and_name(namespace_id, name)
            .await
    }

    async fn list_by_namespace_id(&mut self, namespace_id: NamespaceId) -> Result<Vec<Table>> {
        self.backing
            .repositories()
            .tables()
            .list_by_namespace_id(namespace_id)
            .await
    }

    async fn list(&mut self) -> Result<Vec<Table>> {
        self.backing.repositories().tables().list().await
    }

    async fn snapshot(&mut self, table_id: TableId) -> Result<TableSnapshot> {
        self.backing
            .repositories()
            .tables()
            .snapshot(table_id)
            .await
    }
}

#[async_trait]
impl ColumnRepo for Repos {
    async fn create_or_get(
        &mut self,
        name: &str,
        table_id: TableId,
        column_type: ColumnType,
    ) -> Result<Column> {
        self.backing
            .repositories()
            .columns()
            .create_or_get(name, table_id, column_type)
            .await
    }

    async fn create_or_get_many_unchecked(
        &mut self,
        table_id: TableId,
        columns: HashMap<&str, ColumnType>,
    ) -> Result<Vec<Column>> {
        self.backing
            .repositories()
            .columns()
            .create_or_get_many_unchecked(table_id, columns)
            .await
    }

    async fn list_by_namespace_id(&mut self, namespace_id: NamespaceId) -> Result<Vec<Column>> {
        self.backing
            .repositories()
            .columns()
            .list_by_namespace_id(namespace_id)
            .await
    }

    async fn list_by_table_id(&mut self, table_id: TableId) -> Result<Vec<Column>> {
        self.backing
            .repositories()
            .columns()
            .list_by_table_id(table_id)
            .await
    }

    async fn list(&mut self) -> Result<Vec<Column>> {
        self.backing.repositories().columns().list().await
    }
}

#[async_trait]
impl PartitionRepo for Repos {
    async fn create_or_get(&mut self, key: PartitionKey, table_id: TableId) -> Result<Partition> {
        // read-through: need to wire up table snapshots to look this up efficiently
        self.backing
            .repositories()
            .partitions()
            .create_or_get(key, table_id)
            .await
    }

    async fn get_by_id_batch(&mut self, partition_ids: &[PartitionId]) -> Result<Vec<Partition>> {
        futures::stream::iter(prepare_set(partition_ids.iter().cloned()))
            .map(|p_id| {
                let this = &self;
                async move {
                    let snapshot = match this.get_partition(p_id).await {
                        Ok(s) => s,
                        Err(Error::NotFound { .. }) => {
                            // unknown partitions are silently dropped from the result set
                            return Ok(futures::stream::empty().boxed());
                        }
                        Err(e) => {
                            return Err(e);
                        }
                    };

                    match snapshot.partition() {
                        Ok(p) => Ok(futures::stream::once(async move { Ok(p) }).boxed()),
                        Err(e) => Err(Error::from(e)),
                    }
                }
            })
            .buffer_unordered(self.quorum_fanout)
            .try_flatten()
            .try_collect::<Vec<_>>()
            .await
    }

    async fn list_by_table_id(&mut self, table_id: TableId) -> Result<Vec<Partition>> {
        // read-through: need to wire up table snapshots to look this up efficiently
        self.backing
            .repositories()
            .partitions()
            .list_by_table_id(table_id)
            .await
    }

    async fn list_ids(&mut self) -> Result<Vec<PartitionId>> {
        // read-through: only used for testing, we should eventually remove this interface
        self.backing.repositories().partitions().list_ids().await
    }

    async fn cas_sort_key(
        &mut self,
        partition_id: PartitionId,
        old_sort_key_ids: Option<&SortKeyIds>,
        new_sort_key_ids: &SortKeyIds,
    ) -> Result<Partition, CasFailure<SortKeyIds>> {
        let res = self
            .backing
            .repositories()
            .partitions()
            .cas_sort_key(partition_id, old_sort_key_ids, new_sort_key_ids)
            .await?;

        // write-through: refresh the cached snapshot so readers observe the new sort key
        self.refresh_partition(partition_id)
            .await
            .map_err(CasFailure::QueryError)?;

        Ok(res)
    }

    #[allow(clippy::too_many_arguments)]
    async fn record_skipped_compaction(
        &mut self,
        partition_id: PartitionId,
        reason: &str,
        num_files: usize,
        limit_num_files: usize,
        limit_num_files_first_in_partition: usize,
        estimated_bytes: u64,
        limit_bytes: u64,
    ) -> Result<()> {
        self.backing
            .repositories()
            .partitions()
            .record_skipped_compaction(
                partition_id,
                reason,
                num_files,
                limit_num_files,
                limit_num_files_first_in_partition,
                estimated_bytes,
                limit_bytes,
            )
            .await?;

        self.refresh_partition(partition_id).await?;

        Ok(())
    }

    async fn get_in_skipped_compactions(
        &mut self,
        partition_id: &[PartitionId],
    ) -> Result<Vec<SkippedCompaction>> {
        futures::stream::iter(prepare_set(partition_id.iter().cloned()))
            .map(|p_id| {
                let this = &self;
                async move {
                    let snapshot = match this.get_partition(p_id).await {
                        Ok(s) => s,
                        Err(Error::NotFound { .. }) => {
                            return Ok(futures::stream::empty().boxed());
                        }
                        Err(e) => {
                            return Err(e);
                        }
                    };

                    match snapshot.skipped_compaction() {
                        Some(sc) => Ok(futures::stream::once(async move { Ok(sc) }).boxed()),
                        None => Ok(futures::stream::empty().boxed()),
                    }
                }
            })
            .buffer_unordered(self.quorum_fanout)
            .try_flatten()
            .try_collect::<Vec<_>>()
            .await
    }

    async fn list_skipped_compactions(&mut self) -> Result<Vec<SkippedCompaction>> {
        // read-through: used for debugging, this should be replaced w/ proper hierarchy-traversal
        self.backing
            .repositories()
            .partitions()
            .list_skipped_compactions()
            .await
    }

    async fn delete_skipped_compactions(
        &mut self,
        partition_id: PartitionId,
    ) -> Result<Option<SkippedCompaction>> {
        let res = self
            .backing
            .repositories()
            .partitions()
            .delete_skipped_compactions(partition_id)
            .await?;

        self.refresh_partition(partition_id).await?;

        Ok(res)
    }

    async fn most_recent_n(&mut self, n: usize) -> Result<Vec<Partition>> {
        // read-through: used for ingester warm-up at the moment
        self.backing
            .repositories()
            .partitions()
            .most_recent_n(n)
            .await
    }

    async fn partitions_new_file_between(
        &mut self,
        minimum_time: Timestamp,
        maximum_time: Option<Timestamp>,
    ) -> Result<Vec<PartitionId>> {
        // read-through: used by the compactor for scheduling, we should eventually find a better interface
        self.backing
            .repositories()
            .partitions()
            .partitions_new_file_between(minimum_time, maximum_time)
            .await
    }
async fn list_old_style(&mut self) -> Result> { + // read-through: used by the ingester due to hash-id stuff + self.backing + .repositories() + .partitions() + .list_old_style() + .await + } + + async fn snapshot(&mut self, partition_id: PartitionId) -> Result { + self.get_partition(partition_id).await + } +} + +#[async_trait] +impl ParquetFileRepo for Repos { + async fn flag_for_delete_by_retention(&mut self) -> Result> { + let res = self + .backing + .repositories() + .parquet_files() + .flag_for_delete_by_retention() + .await?; + + let affected_partitions = res + .iter() + .map(|(p_id, _os_id)| *p_id) + .collect::>(); + + // ensure deterministic order + let mut affected_partitions = affected_partitions.into_iter().collect::>(); + affected_partitions.sort_unstable(); + + // refresh ALL partitons that are affected, NOT just only the ones that were cached. This should avoid the + // following "lost update" race condition: + // + // This scenario assumes that the partition in question is NOT cached yet. + // + // | T | Thread 1 | Thread 2 | + // | - | ------------------------------------- | -------------------------------------------------- | + // | 1 | receive `create_update_delete` | | + // | 2 | execute change within backing catalog | | + // | 3 | takes snapshot from backing catalog | | + // | 4 | | receive `flag_for_delete_by_retention` | + // | 5 | | execute change within backing catalog | + // | 6 | | affected partition not cached => no snapshot taken | + // | 7 | | return | + // | 8 | quorum-write snapshot | | + // | 9 | return | | + // + // The partition is now cached by does NOT contain the `flag_for_delete_by_retention` change and will not + // automatically converge. 
+ futures::stream::iter(affected_partitions) + .map(|p_id| { + let this = &self; + async move { + this.refresh_partition(p_id).await?; + Ok::<(), Error>(()) + } + }) + .buffer_unordered(self.quorum_fanout) + .try_collect::<()>() + .await?; + + Ok(res) + } + + async fn delete_old_ids_only(&mut self, older_than: Timestamp) -> Result> { + // deleted files are NOT part of the snapshot, so this bypasses the cache + self.backing + .repositories() + .parquet_files() + .delete_old_ids_only(older_than) + .await + } + + async fn list_by_partition_not_to_delete_batch( + &mut self, + partition_ids: Vec, + ) -> Result> { + futures::stream::iter(prepare_set(partition_ids)) + .map(|p_id| { + let this = &self; + async move { + let snapshot = match this.get_partition(p_id).await { + Ok(s) => s, + Err(Error::NotFound { .. }) => { + return Ok(futures::stream::empty().boxed()); + } + Err(e) => { + return Err(e); + } + }; + + // Decode files so we can drop the snapshot early. + // + // Need to collect the file results into a vec though because we cannot return borrowed data and + // "owned iterators" aren't a thing. 
+ let files = snapshot + .files() + .map(|res| res.map_err(Error::from)) + .collect::>(); + Ok::<_, Error>(futures::stream::iter(files).boxed()) + } + }) + .buffer_unordered(self.quorum_fanout) + .try_flatten() + .try_collect::>() + .await + } + + async fn get_by_object_store_id( + &mut self, + object_store_id: ObjectStoreId, + ) -> Result> { + // read-through: see https://github.com/influxdata/influxdb_iox/issues/9719 + self.backing + .repositories() + .parquet_files() + .get_by_object_store_id(object_store_id) + .await + } + + async fn exists_by_object_store_id_batch( + &mut self, + object_store_ids: Vec, + ) -> Result> { + // read-through: this is used by the GC, so this is not overall latency-critical + self.backing + .repositories() + .parquet_files() + .exists_by_object_store_id_batch(object_store_ids) + .await + } + + async fn create_upgrade_delete( + &mut self, + partition_id: PartitionId, + delete: &[ObjectStoreId], + upgrade: &[ObjectStoreId], + create: &[ParquetFileParams], + target_level: CompactionLevel, + ) -> Result> { + let res = self + .backing + .repositories() + .parquet_files() + .create_upgrade_delete(partition_id, delete, upgrade, create, target_level) + .await?; + + self.refresh_partition(partition_id).await?; + + Ok(res) + } +} + +/// Prepare set of elements in deterministic order. 
+fn prepare_set(set: S) -> Vec +where + S: IntoIterator, + T: Eq + Ord, +{ + // ensure deterministic order (also required for de-dup) + let mut set = set.into_iter().collect::>(); + set.sort_unstable(); + + // de-dup + set.dedup(); + + set +} + +#[cfg(test)] +mod tests { + use catalog_cache::api::server::test_util::TestCacheServer; + use catalog_cache::local::CatalogCache; + use iox_time::SystemProvider; + + use crate::{interface_tests::TestCatalog, mem::MemCatalog}; + + use super::*; + use std::sync::Arc; + + #[tokio::test] + async fn test_catalog() { + crate::interface_tests::test_catalog(|| async { + let metrics = Arc::new(metric::Registry::default()); + let time_provider = Arc::new(SystemProvider::new()) as _; + let backing = Arc::new(MemCatalog::new(metrics, Arc::clone(&time_provider))); + + let peer0 = TestCacheServer::bind_ephemeral(); + let peer1 = TestCacheServer::bind_ephemeral(); + let cache = Arc::new(QuorumCatalogCache::new( + Arc::new(CatalogCache::default()), + Arc::new([peer0.client(), peer1.client()]), + )); + + // use new metrics registry so the two layers don't double-count + let metrics = Arc::new(metric::Registry::default()); + let caching_catalog = Arc::new(CachingCatalog::new( + cache, + backing, + metrics, + time_provider, + 10, + )); + + let test_catalog = TestCatalog::new(caching_catalog); + test_catalog.hold_onto(peer0); + test_catalog.hold_onto(peer1); + + Arc::new(test_catalog) as _ + }) + .await; + } +} diff --git a/iox_catalog/src/constants.rs b/iox_catalog/src/constants.rs new file mode 100644 index 0000000..b6b88fb --- /dev/null +++ b/iox_catalog/src/constants.rs @@ -0,0 +1,19 @@ +//! Constants that are hold for all catalog implementations. + +/// Time column. +pub const TIME_COLUMN: &str = "time"; + +/// Default retention period for data in the catalog. +pub const DEFAULT_RETENTION_PERIOD: Option = None; + +/// Maximum number of files touched by [`ParquetFileRepo::flag_for_delete_by_retention`] at a time. 
+/// +/// +/// [`ParquetFileRepo::flag_for_delete_by_retention`]: crate::interface::ParquetFileRepo::flag_for_delete_by_retention +pub const MAX_PARQUET_FILES_SELECTED_ONCE_FOR_RETENTION: i64 = 1_000; + +/// Maximum number of files touched by [`ParquetFileRepo::delete_old_ids_only`] at a time. +/// +/// +/// [`ParquetFileRepo::delete_old_ids_only`]: crate::interface::ParquetFileRepo::delete_old_ids_only +pub const MAX_PARQUET_FILES_SELECTED_ONCE_FOR_DELETE: i64 = 10_000; diff --git a/iox_catalog/src/grpc/client.rs b/iox_catalog/src/grpc/client.rs new file mode 100644 index 0000000..8edc05d --- /dev/null +++ b/iox_catalog/src/grpc/client.rs @@ -0,0 +1,997 @@ +//! gRPC client implementation. +use std::future::Future; +use std::ops::ControlFlow; +use std::{collections::HashMap, sync::Arc}; + +use async_trait::async_trait; +use futures::TryStreamExt; +use log::{debug, info, warn}; +use tonic::transport::{Channel, Uri}; + +use crate::{ + interface::{ + CasFailure, Catalog, ColumnRepo, Error, NamespaceRepo, ParquetFileRepo, PartitionRepo, + RepoCollection, Result, SoftDeletedRows, TableRepo, + }, + metrics::MetricDecorator, +}; +use backoff::{Backoff, BackoffError}; +use data_types::snapshot::partition::PartitionSnapshot; +use data_types::{ + partition_template::{NamespacePartitionTemplateOverride, TablePartitionTemplateOverride}, + snapshot::table::TableSnapshot, + Column, ColumnType, CompactionLevel, MaxColumnsPerTable, MaxTables, Namespace, NamespaceId, + NamespaceName, NamespaceServiceProtectionLimitsOverride, ObjectStoreId, ParquetFile, + ParquetFileId, ParquetFileParams, Partition, PartitionId, PartitionKey, SkippedCompaction, + SortKeyIds, Table, TableId, Timestamp, +}; +use generated_types::influxdata::iox::catalog::v2 as proto; +use iox_time::TimeProvider; +use trace_http::metrics::{MetricFamily, RequestMetrics}; +use trace_http::tower::TraceService; + +use super::serialization::{ + convert_status, deserialize_column, deserialize_namespace, 
deserialize_object_store_id, + deserialize_parquet_file, deserialize_partition, deserialize_skipped_compaction, + deserialize_sort_key_ids, deserialize_table, serialize_column_type, serialize_object_store_id, + serialize_parquet_file_params, serialize_soft_deleted_rows, serialize_sort_key_ids, ContextExt, + RequiredExt, +}; + +type InstrumentedChannel = TraceService; + +/// Catalog that goes through a gRPC interface. +#[derive(Debug)] +pub struct GrpcCatalogClient { + channel: InstrumentedChannel, + metrics: Arc, + time_provider: Arc, +} + +impl GrpcCatalogClient { + /// Create new client. + pub fn new( + uri: Uri, + metrics: Arc, + time_provider: Arc, + ) -> Self { + let channel = TraceService::new_client( + Channel::builder(uri).connect_lazy(), + Arc::new(RequestMetrics::new( + Arc::clone(&metrics), + MetricFamily::GrpcClient, + )), + None, + "catalog", + ); + Self { + channel, + metrics, + time_provider, + } + } +} + +impl std::fmt::Display for GrpcCatalogClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "grpc") + } +} + +#[async_trait] +impl Catalog for GrpcCatalogClient { + async fn setup(&self) -> Result<(), Error> { + Ok(()) + } + + fn repositories(&self) -> Box { + Box::new(MetricDecorator::new( + GrpcCatalogClientRepos { + channel: self.channel.clone(), + }, + Arc::clone(&self.metrics), + Arc::clone(&self.time_provider), + )) + } + + #[cfg(test)] + fn metrics(&self) -> Arc { + Arc::clone(&self.metrics) + } + + fn time_provider(&self) -> Arc { + Arc::clone(&self.time_provider) + } +} + +#[derive(Debug)] +struct GrpcCatalogClientRepos { + channel: InstrumentedChannel, +} + +type ServiceClient = proto::catalog_service_client::CatalogServiceClient; + +fn is_upstream_error(e: &tonic::Status) -> bool { + matches!( + e.code(), + tonic::Code::Cancelled + | tonic::Code::DeadlineExceeded + | tonic::Code::FailedPrecondition + | tonic::Code::Aborted + | tonic::Code::Unavailable + ) +} + +impl GrpcCatalogClientRepos { + fn 
client(&self) -> ServiceClient { + proto::catalog_service_client::CatalogServiceClient::new(self.channel.clone()) + } + + async fn retry( + &self, + operation: &str, + upload: U, + fun_io: FunIo, + ) -> Result + where + U: Clone + std::fmt::Debug + Send + Sync, + FunIo: Fn(U, ServiceClient) -> Fut + Send + Sync, + Fut: Future, tonic::Status>> + Send, + D: std::fmt::Debug, + { + Backoff::new(&Default::default()) + .retry_with_backoff(operation, || async { + let res = fun_io(upload.clone(), self.client()).await; + match res { + Ok(r) => { + let r = r.into_inner(); + debug!("{} successfully received: {:?}", operation, &r); + ControlFlow::Break(Ok(r)) + } + Err(e) if is_upstream_error(&e) => { + info!("{} retriable error encountered: {:?}", operation, &e); + ControlFlow::Continue(e) + } + Err(e) => { + warn!( + "{operation} attempted {:?} and received error: {:?}", + upload, e + ); + ControlFlow::Break(Err(convert_status(e))) + } + } + }) + .await + .map_err(|be| { + let status = match be { + BackoffError::DeadlineExceeded { source, .. } => source, + }; + convert_status(status) + })? 
+ } +} + +impl RepoCollection for GrpcCatalogClientRepos { + fn namespaces(&mut self) -> &mut dyn NamespaceRepo { + self + } + + fn tables(&mut self) -> &mut dyn TableRepo { + self + } + + fn columns(&mut self) -> &mut dyn ColumnRepo { + self + } + + fn partitions(&mut self) -> &mut dyn PartitionRepo { + self + } + + fn parquet_files(&mut self) -> &mut dyn ParquetFileRepo { + self + } +} + +#[async_trait] +impl NamespaceRepo for GrpcCatalogClientRepos { + async fn create( + &mut self, + name: &NamespaceName<'_>, + partition_template: Option, + retention_period_ns: Option, + service_protection_limits: Option, + ) -> Result { + let n = proto::NamespaceCreateRequest { + name: name.to_string(), + partition_template: partition_template.and_then(|t| t.as_proto().cloned()), + retention_period_ns, + service_protection_limits: service_protection_limits.map(|l| { + proto::ServiceProtectionLimits { + max_tables: l.max_tables.map(|x| x.get_i32()), + max_columns_per_table: l.max_columns_per_table.map(|x| x.get_i32()), + } + }), + }; + + let resp = self + .retry("namespace_create", n, |data, mut client| async move { + client.namespace_create(data).await + }) + .await?; + + Ok(deserialize_namespace( + resp.namespace.required().ctx("namespace")?, + )?) + } + + async fn update_retention_period( + &mut self, + name: &str, + retention_period_ns: Option, + ) -> Result { + let n = proto::NamespaceUpdateRetentionPeriodRequest { + name: name.to_owned(), + retention_period_ns, + }; + + let resp = self.retry( + "namespace_update_retention_period", + n, + |data, mut client| async move { client.namespace_update_retention_period(data).await }, + ) + .await?; + + Ok(deserialize_namespace( + resp.namespace.required().ctx("namespace")?, + )?) 
+ } + + async fn list(&mut self, deleted: SoftDeletedRows) -> Result> { + let n = proto::NamespaceListRequest { + deleted: serialize_soft_deleted_rows(deleted), + }; + + self.retry("namespace_list", n, |data, mut client| async move { + client.namespace_list(data).await + }) + .await? + .map_err(convert_status) + .and_then(|res| async move { + deserialize_namespace(res.namespace.required().ctx("namespace")?).map_err(Error::from) + }) + .try_collect() + .await + } + + async fn get_by_id( + &mut self, + id: NamespaceId, + deleted: SoftDeletedRows, + ) -> Result> { + let n = proto::NamespaceGetByIdRequest { + id: id.get(), + deleted: serialize_soft_deleted_rows(deleted), + }; + + let resp = self + .retry("namespace_get_by_id", n, |data, mut client| async move { + client.namespace_get_by_id(data).await + }) + .await?; + Ok(resp.namespace.map(deserialize_namespace).transpose()?) + } + + async fn get_by_name( + &mut self, + name: &str, + deleted: SoftDeletedRows, + ) -> Result> { + let n = proto::NamespaceGetByNameRequest { + name: name.to_owned(), + deleted: serialize_soft_deleted_rows(deleted), + }; + + let resp = self + .retry("namespace_get_by_name", n, |data, mut client| async move { + client.namespace_get_by_name(data).await + }) + .await?; + Ok(resp.namespace.map(deserialize_namespace).transpose()?) 
+ } + + async fn soft_delete(&mut self, name: &str) -> Result<()> { + let n = proto::NamespaceSoftDeleteRequest { + name: name.to_owned(), + }; + + self.retry("namespace_soft_delete", n, |data, mut client| async move { + client.namespace_soft_delete(data).await + }) + .await?; + Ok(()) + } + + async fn update_table_limit(&mut self, name: &str, new_max: MaxTables) -> Result { + let n = proto::NamespaceUpdateTableLimitRequest { + name: name.to_owned(), + new_max: new_max.get_i32(), + }; + + let resp = self + .retry("namespace_soft_delete", n, |data, mut client| async move { + client.namespace_update_table_limit(data).await + }) + .await?; + + Ok(deserialize_namespace( + resp.namespace.required().ctx("namespace")?, + )?) + } + + async fn update_column_limit( + &mut self, + name: &str, + new_max: MaxColumnsPerTable, + ) -> Result { + let n = proto::NamespaceUpdateColumnLimitRequest { + name: name.to_owned(), + new_max: new_max.get_i32(), + }; + + let resp = self + .retry("namespace_soft_delete", n, |data, mut client| async move { + client.namespace_update_column_limit(data).await + }) + .await?; + + Ok(deserialize_namespace( + resp.namespace.required().ctx("namespace")?, + )?) + } +} + +#[async_trait] +impl TableRepo for GrpcCatalogClientRepos { + async fn create( + &mut self, + name: &str, + partition_template: TablePartitionTemplateOverride, + namespace_id: NamespaceId, + ) -> Result
{ + let t = proto::TableCreateRequest { + name: name.to_owned(), + partition_template: partition_template.as_proto().cloned(), + namespace_id: namespace_id.get(), + }; + + let resp = self + .retry("table_create", t, |data, mut client| async move { + client.table_create(data).await + }) + .await?; + Ok(deserialize_table(resp.table.required().ctx("table")?)?) + } + + async fn get_by_id(&mut self, table_id: TableId) -> Result> { + let t = proto::TableGetByIdRequest { id: table_id.get() }; + + let resp = self + .retry("table_get_by_id", t, |data, mut client| async move { + client.table_get_by_id(data).await + }) + .await?; + Ok(resp.table.map(deserialize_table).transpose()?) + } + + async fn get_by_namespace_and_name( + &mut self, + namespace_id: NamespaceId, + name: &str, + ) -> Result> { + let t = proto::TableGetByNamespaceAndNameRequest { + namespace_id: namespace_id.get(), + name: name.to_owned(), + }; + + let resp = self.retry( + "table_get_by_namespace_and_name", + t, + |data, mut client| async move { client.table_get_by_namespace_and_name(data).await }, + ) + .await?; + Ok(resp.table.map(deserialize_table).transpose()?) + } + + async fn list_by_namespace_id(&mut self, namespace_id: NamespaceId) -> Result> { + let t = proto::TableListByNamespaceIdRequest { + namespace_id: namespace_id.get(), + }; + + self.retry( + "table_list_by_namespace_id", + t, + |data, mut client| async move { client.table_list_by_namespace_id(data).await }, + ) + .await? + .map_err(convert_status) + .and_then(|res| async move { Ok(deserialize_table(res.table.required().ctx("table")?)?) }) + .try_collect() + .await + } + + async fn list(&mut self) -> Result> { + let t = proto::TableListRequest {}; + + self.retry("table_list", t, |data, mut client| async move { + client.table_list(data).await + }) + .await? + .map_err(convert_status) + .and_then(|res| async move { Ok(deserialize_table(res.table.required().ctx("table")?)?) 
}) + .try_collect() + .await + } + + async fn snapshot(&mut self, table_id: TableId) -> Result { + let t = proto::TableSnapshotRequest { + table_id: table_id.get(), + }; + + let resp = self + .retry("table_snapshot", t, |data, mut client| async move { + client.table_snapshot(data).await + }) + .await?; + + let table = resp.table.required().ctx("table")?; + Ok(TableSnapshot::decode(table, resp.generation)) + } +} + +#[async_trait] +impl ColumnRepo for GrpcCatalogClientRepos { + async fn create_or_get( + &mut self, + name: &str, + table_id: TableId, + column_type: ColumnType, + ) -> Result { + let c = proto::ColumnCreateOrGetRequest { + name: name.to_owned(), + table_id: table_id.get(), + column_type: serialize_column_type(column_type), + }; + + let resp = self + .retry("column_create_or_get", c, |data, mut client| async move { + client.column_create_or_get(data).await + }) + .await?; + Ok(deserialize_column(resp.column.required().ctx("column")?)?) + } + + async fn create_or_get_many_unchecked( + &mut self, + table_id: TableId, + columns: HashMap<&str, ColumnType>, + ) -> Result> { + let c = proto::ColumnCreateOrGetManyUncheckedRequest { + table_id: table_id.get(), + columns: columns + .into_iter() + .map(|(name, t)| (name.to_owned(), serialize_column_type(t))) + .collect(), + }; + + self.retry( + "column_create_or_get_many_unchecked", + c, + |data, mut client| async move { client.column_create_or_get_many_unchecked(data).await }, + ) + .await? + .map_err(convert_status) + .and_then(|res| async move { + Ok(deserialize_column(res.column.required().ctx("column")?)?) + }) + .try_collect() + .await + } + + async fn list_by_namespace_id(&mut self, namespace_id: NamespaceId) -> Result> { + let c = proto::ColumnListByNamespaceIdRequest { + namespace_id: namespace_id.get(), + }; + + self.retry( + "column_list_by_namespace_id", + c, + |data, mut client| async move { client.column_list_by_namespace_id(data).await }, + ) + .await? 
+ .map_err(convert_status) + .and_then( + |res| async move { Ok(deserialize_column(res.column.required().ctx("column")?)?) }, + ) + .try_collect() + .await + } + + async fn list_by_table_id(&mut self, table_id: TableId) -> Result> { + let c = proto::ColumnListByTableIdRequest { + table_id: table_id.get(), + }; + + self.retry( + "column_list_by_table_id", + c, + |data, mut client| async move { client.column_list_by_table_id(data).await }, + ) + .await? + .map_err(convert_status) + .and_then( + |res| async move { Ok(deserialize_column(res.column.required().ctx("column")?)?) }, + ) + .try_collect() + .await + } + + async fn list(&mut self) -> Result> { + let c = proto::ColumnListRequest {}; + + self.retry("column_list", c, |data, mut client| async move { + client.column_list(data).await + }) + .await? + .map_err(convert_status) + .and_then( + |res| async move { Ok(deserialize_column(res.column.required().ctx("column")?)?) }, + ) + .try_collect() + .await + } +} + +#[async_trait] +impl PartitionRepo for GrpcCatalogClientRepos { + async fn create_or_get(&mut self, key: PartitionKey, table_id: TableId) -> Result { + let p = proto::PartitionCreateOrGetRequest { + key: key.inner().to_owned(), + table_id: table_id.get(), + }; + + let resp = self + .retry( + "partition_create_or_get", + p, + |data, mut client| async move { client.partition_create_or_get(data).await }, + ) + .await?; + + Ok(deserialize_partition( + resp.partition.required().ctx("partition")?, + )?) + } + + async fn get_by_id_batch(&mut self, partition_ids: &[PartitionId]) -> Result> { + let p = proto::PartitionGetByIdBatchRequest { + partition_ids: partition_ids.iter().map(|id| id.get()).collect(), + }; + + self.retry( + "partition_get_by_id_batch", + p, + |data, mut client| async move { client.partition_get_by_id_batch(data).await }, + ) + .await? + .map_err(convert_status) + .and_then(|res| async move { + Ok(deserialize_partition( + res.partition.required().ctx("partition")?, + )?) 
+ }) + .try_collect() + .await + } + + async fn list_by_table_id(&mut self, table_id: TableId) -> Result> { + let p = proto::PartitionListByTableIdRequest { + table_id: table_id.get(), + }; + + self.retry( + "partition_list_by_table_id", + p, + |data, mut client| async move { client.partition_list_by_table_id(data).await }, + ) + .await? + .map_err(convert_status) + .and_then(|res| async move { + Ok(deserialize_partition( + res.partition.required().ctx("partition")?, + )?) + }) + .try_collect() + .await + } + + async fn list_ids(&mut self) -> Result> { + let p = proto::PartitionListIdsRequest {}; + + self.retry("partition_list_ids", p, |data, mut client| async move { + client.partition_list_ids(data).await + }) + .await? + .map_err(convert_status) + .map_ok(|res| PartitionId::new(res.partition_id)) + .try_collect() + .await + } + + async fn cas_sort_key( + &mut self, + partition_id: PartitionId, + old_sort_key_ids: Option<&SortKeyIds>, + new_sort_key_ids: &SortKeyIds, + ) -> Result> { + // This method does not use request/request_streaming_response + // because the error handling (converting to CasFailure) differs + // from how all the other methods handle errors. 
+ + let p = proto::PartitionCasSortKeyRequest { + partition_id: partition_id.get(), + old_sort_key_ids: old_sort_key_ids.map(serialize_sort_key_ids), + new_sort_key_ids: Some(serialize_sort_key_ids(new_sort_key_ids)), + }; + + let res = self + .retry("partition_cas_sort_key", p, |data, mut client| async move { + client.partition_cas_sort_key(data).await + }) + .await + .map_err(CasFailure::QueryError)?; + + let res = res + .res + .required() + .ctx("res") + .map_err(|e| CasFailure::QueryError(e.into()))?; + + match res { + proto::partition_cas_sort_key_response::Res::Partition(p) => { + let p = deserialize_partition(p).map_err(|e| CasFailure::QueryError(e.into()))?; + Ok(p) + } + proto::partition_cas_sort_key_response::Res::CurrentSortKey(k) => { + Err(CasFailure::ValueMismatch(deserialize_sort_key_ids(k))) + } + } + } + + #[allow(clippy::too_many_arguments)] + async fn record_skipped_compaction( + &mut self, + partition_id: PartitionId, + reason: &str, + num_files: usize, + limit_num_files: usize, + limit_num_files_first_in_partition: usize, + estimated_bytes: u64, + limit_bytes: u64, + ) -> Result<()> { + let p = proto::PartitionRecordSkippedCompactionRequest { + partition_id: partition_id.get(), + reason: reason.to_owned(), + num_files: num_files as u64, + limit_num_files: limit_num_files as u64, + limit_num_files_first_in_partition: limit_num_files_first_in_partition as u64, + estimated_bytes, + limit_bytes, + }; + + self.retry( + "partition_record_skipped_compaction", + p, + |data, mut client| async move { client.partition_record_skipped_compaction(data).await }, + ) + .await?; + Ok(()) + } + + async fn get_in_skipped_compactions( + &mut self, + partition_id: &[PartitionId], + ) -> Result> { + let p = proto::PartitionGetInSkippedCompactionsRequest { + partition_ids: partition_id.iter().map(|id| id.get()).collect(), + }; + + self.retry( + "partition_get_in_skipped_compactions", + p, + |data, mut client| async move { 
client.partition_get_in_skipped_compactions(data).await }, + ) + .await? + .map_err(convert_status) + .and_then(|res| async move { + Ok(deserialize_skipped_compaction(res.skipped_compaction.required().ctx("skipped_compaction")?)) + }) + .try_collect() + .await + } + + async fn list_skipped_compactions(&mut self) -> Result> { + let p = proto::PartitionListSkippedCompactionsRequest {}; + + self.retry( + "partition_list_skipped_compactions", + p, + |data, mut client| async move { client.partition_list_skipped_compactions(data).await }, + ) + .await? + .map_err(convert_status) + .and_then(|res| async move { + Ok(deserialize_skipped_compaction( + res.skipped_compaction + .required() + .ctx("skipped_compaction")?, + )) + }) + .try_collect() + .await + } + + async fn delete_skipped_compactions( + &mut self, + partition_id: PartitionId, + ) -> Result> { + let p = proto::PartitionDeleteSkippedCompactionsRequest { + partition_id: partition_id.get(), + }; + + let resp = self + .retry( + "partition_delete_skipped_compactions", + p, + |data, mut client| async move { + client.partition_delete_skipped_compactions(data).await + }, + ) + .await?; + + Ok(resp.skipped_compaction.map(deserialize_skipped_compaction)) + } + + async fn most_recent_n(&mut self, n: usize) -> Result> { + let p = proto::PartitionMostRecentNRequest { n: n as u64 }; + + self.retry( + "partition_most_recent_n", + p, + |data, mut client| async move { client.partition_most_recent_n(data).await }, + ) + .await? + .map_err(convert_status) + .and_then(|res| async move { + Ok(deserialize_partition( + res.partition.required().ctx("partition")?, + )?) 
+ }) + .try_collect() + .await + } + + async fn partitions_new_file_between( + &mut self, + minimum_time: Timestamp, + maximum_time: Option, + ) -> Result> { + let p = proto::PartitionNewFileBetweenRequest { + minimum_time: minimum_time.get(), + maximum_time: maximum_time.map(|ts| ts.get()), + }; + + self.retry( + "partition_new_file_between", + p, + |data, mut client| async move { client.partition_new_file_between(data).await }, + ) + .await? + .map_err(convert_status) + .map_ok(|res| PartitionId::new(res.partition_id)) + .try_collect() + .await + } + + async fn list_old_style(&mut self) -> Result> { + let p = proto::PartitionListOldStyleRequest {}; + + self.retry( + "partition_list_old_style", + p, + |data, mut client| async move { client.partition_list_old_style(data).await }, + ) + .await? + .map_err(convert_status) + .and_then(|res| async move { + Ok(deserialize_partition( + res.partition.required().ctx("partition")?, + )?) + }) + .try_collect() + .await + } + + async fn snapshot(&mut self, partition_id: PartitionId) -> Result { + let p = proto::PartitionSnapshotRequest { + partition_id: partition_id.get(), + }; + + let resp = self + .retry("partition_snapshot", p, |data, mut client| async move { + client.partition_snapshot(data).await + }) + .await?; + let partition = resp.partition.required().ctx("partition")?; + Ok(PartitionSnapshot::decode(partition, resp.generation)) + } +} + +#[async_trait] +impl ParquetFileRepo for GrpcCatalogClientRepos { + async fn flag_for_delete_by_retention(&mut self) -> Result> { + let p = proto::ParquetFileFlagForDeleteByRetentionRequest {}; + + self.retry( + "parquet_file_flag_for_delete_by_retention", + p, + |data, mut client| async move { + client.parquet_file_flag_for_delete_by_retention(data).await + }, + ) + .await? 
+ .map_err(convert_status) + .and_then(|res| async move { + Ok(( + PartitionId::new(res.partition_id), + deserialize_object_store_id(res.object_store_id.required().ctx("object_store_id")?), + )) + }) + .try_collect() + .await + } + + async fn delete_old_ids_only(&mut self, older_than: Timestamp) -> Result> { + let p = proto::ParquetFileDeleteOldIdsOnlyRequest { + older_than: older_than.get(), + }; + + self.retry( + "parquet_file_delete_old_ids_only", + p, + |data, mut client| async move { client.parquet_file_delete_old_ids_only(data).await }, + ) + .await? + .map_err(convert_status) + .and_then(|res| async move { + Ok(deserialize_object_store_id( + res.object_store_id.required().ctx("object_store_id")?, + )) + }) + .try_collect() + .await + } + + async fn list_by_partition_not_to_delete_batch( + &mut self, + partition_ids: Vec, + ) -> Result> { + let p = proto::ParquetFileListByPartitionNotToDeleteBatchRequest { + partition_ids: partition_ids.into_iter().map(|p| p.get()).collect(), + }; + + self.retry( + "parquet_file_list_by_partition_not_to_delete_batch", + p, + |data, mut client| async move { + client + .parquet_file_list_by_partition_not_to_delete_batch(data) + .await + }, + ) + .await? + .map_err(convert_status) + .and_then(|res| async move { + Ok(deserialize_parquet_file( + res.parquet_file.required().ctx("parquet_file")?, + )?) + }) + .try_collect() + .await + } + + async fn get_by_object_store_id( + &mut self, + object_store_id: ObjectStoreId, + ) -> Result> { + let p = proto::ParquetFileGetByObjectStoreIdRequest { + object_store_id: Some(serialize_object_store_id(object_store_id)), + }; + + let maybe_file = self.retry( + "parquet_file_get_by_object_store_id", + p, + |data, mut client| async move { client.parquet_file_get_by_object_store_id(data).await }) + .await? 
+ .parquet_file.map(deserialize_parquet_file).transpose()?; + Ok(maybe_file) + } + + async fn exists_by_object_store_id_batch( + &mut self, + object_store_ids: Vec, + ) -> Result> { + let p = futures::stream::iter(object_store_ids.into_iter().map(|id| { + proto::ParquetFileExistsByObjectStoreIdBatchRequest { + object_store_id: Some(serialize_object_store_id(id)), + } + })); + + self.retry( + "parquet_file_exists_by_object_store_id_batch", + p, + |data, mut client: ServiceClient| async move { + client + .parquet_file_exists_by_object_store_id_batch(data) + .await + }, + ) + .await? + .map_err(convert_status) + .and_then(|res| async move { + Ok(deserialize_object_store_id( + res.object_store_id.required().ctx("object_store_id")?, + )) + }) + .try_collect() + .await + } + + async fn create_upgrade_delete( + &mut self, + partition_id: PartitionId, + delete: &[ObjectStoreId], + upgrade: &[ObjectStoreId], + create: &[ParquetFileParams], + target_level: CompactionLevel, + ) -> Result> { + let p = proto::ParquetFileCreateUpgradeDeleteRequest { + partition_id: partition_id.get(), + delete: delete + .iter() + .copied() + .map(serialize_object_store_id) + .collect(), + upgrade: upgrade + .iter() + .copied() + .map(serialize_object_store_id) + .collect(), + create: create.iter().map(serialize_parquet_file_params).collect(), + target_level: target_level as i32, + }; + + let resp = self.retry( + "parquet_file_create_upgrade_delete", + p, + |data, mut client| async move { client.parquet_file_create_upgrade_delete(data).await }, + ) + .await?; + + Ok(resp + .created_parquet_file_ids + .into_iter() + .map(ParquetFileId::new) + .collect()) + } +} diff --git a/iox_catalog/src/grpc/mod.rs b/iox_catalog/src/grpc/mod.rs new file mode 100644 index 0000000..0374f57 --- /dev/null +++ b/iox_catalog/src/grpc/mod.rs @@ -0,0 +1,143 @@ +//! gRPC catalog tunnel. +//! +//! This tunnels catalog requests over gRPC. 
+ +pub mod client; +mod serialization; +pub mod server; + +#[cfg(test)] +mod tests { + use std::{net::SocketAddr, sync::Arc}; + + use data_types::NamespaceName; + use iox_time::SystemProvider; + use metric::{Attributes, Metric, U64Counter}; + use test_helpers::maybe_start_logging; + use tokio::{net::TcpListener, task::JoinSet}; + use tonic::transport::{server::TcpIncoming, Server, Uri}; + + use crate::{interface::Catalog, interface_tests::TestCatalog, mem::MemCatalog}; + + use super::*; + + #[tokio::test] + async fn test_catalog() { + maybe_start_logging(); + + crate::interface_tests::test_catalog(|| async { + let metrics = Arc::new(metric::Registry::default()); + let time_provider = Arc::new(SystemProvider::new()) as _; + let backing_catalog = Arc::new(MemCatalog::new(metrics, Arc::clone(&time_provider))); + let test_server = TestServer::new(backing_catalog).await; + let uri = test_server.uri(); + + // create new metrics for client so that they don't overlap w/ server + let metrics = Arc::new(metric::Registry::default()); + let client = Arc::new(client::GrpcCatalogClient::new( + uri, + metrics, + Arc::clone(&time_provider), + )); + + let test_catalog = TestCatalog::new(client); + test_catalog.hold_onto(test_server); + + Arc::new(test_catalog) as _ + }) + .await; + } + + #[tokio::test] + async fn test_catalog_metrics() { + maybe_start_logging(); + + let time_provider = Arc::new(SystemProvider::new()) as _; + let metrics = Arc::new(metric::Registry::default()); + let backing_catalog = Arc::new(MemCatalog::new(metrics, Arc::clone(&time_provider))); + let test_server = TestServer::new(backing_catalog).await; + let uri = test_server.uri(); + + // create new metrics for client so that they don't overlap w/ server + let metrics = Arc::new(metric::Registry::default()); + let client = Arc::new(client::GrpcCatalogClient::new( + uri, + Arc::clone(&metrics), + Arc::clone(&time_provider), + )); + + let ns = client + .repositories() + .namespaces() + 
.create(&NamespaceName::new("testns").unwrap(), None, None, None) + .await + .expect("namespace failed to create"); + + let _ = client + .repositories() + .tables() + .list_by_namespace_id(ns.id) + .await + .expect("failed to list namespaces"); + + let metric = metrics + .get_instrument::>("grpc_client_requests") + .expect("failed to get metric"); + + let count = metric + .get_observer(&Attributes::from(&[ + ( + "path", + "/influxdata.iox.catalog.v2.CatalogService/NamespaceCreate", + ), + ("status", "ok"), + ])) + .unwrap() + .fetch(); + + assert_eq!(count, 1); + + let count = metric + .get_observer(&Attributes::from(&[ + ( + "path", + "/influxdata.iox.catalog.v2.CatalogService/TableListByNamespaceId", + ), + ("status", "ok"), + ])) + .unwrap() + .fetch(); + + assert_eq!(count, 1); + } + + struct TestServer { + addr: SocketAddr, + #[allow(dead_code)] + task: JoinSet<()>, + } + + impl TestServer { + async fn new(catalog: Arc) -> Self { + let listener = TcpListener::bind("0.0.0.0:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let mut task = JoinSet::new(); + task.spawn(async move { + Server::builder() + .add_service(server::GrpcCatalogServer::new(catalog).service()) + .serve_with_incoming(incoming) + .await + .unwrap(); + }); + + Self { addr, task } + } + + fn uri(&self) -> Uri { + format!("http://{}:{}", self.addr.ip(), self.addr.port()) + .parse() + .unwrap() + } + } +} diff --git a/iox_catalog/src/grpc/serialization.rs b/iox_catalog/src/grpc/serialization.rs new file mode 100644 index 0000000..2698dc4 --- /dev/null +++ b/iox_catalog/src/grpc/serialization.rs @@ -0,0 +1,712 @@ +use data_types::{ + partition_template::NamespacePartitionTemplateOverride, Column, ColumnId, ColumnSet, + ColumnType, Namespace, NamespaceId, ObjectStoreId, ParquetFile, ParquetFileId, + ParquetFileParams, Partition, PartitionId, SkippedCompaction, SortKeyIds, Table, TableId, + Timestamp, +}; 
+use generated_types::influxdata::iox::catalog::v2 as proto; +use uuid::Uuid; + +use crate::interface::SoftDeletedRows; + +#[derive(Debug)] +pub struct Error { + msg: String, + path: Vec<&'static str>, +} + +impl Error { + fn new(e: E) -> Self + where + E: std::fmt::Display, + { + Self { + msg: e.to_string(), + path: vec![], + } + } + + fn ctx(self, arg: &'static str) -> Self { + let Self { msg, mut path } = self; + path.insert(0, arg); + Self { msg, path } + } +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if !self.path.is_empty() { + write!(f, "{}", self.path[0])?; + for p in self.path.iter().skip(1) { + write!(f, ".{}", p)?; + } + write!(f, ": ")?; + } + + write!(f, "{}", self.msg)?; + + Ok(()) + } +} + +impl std::error::Error for Error {} + +impl From for crate::interface::Error { + fn from(e: Error) -> Self { + Self::External { source: e.into() } + } +} + +impl From for tonic::Status { + fn from(e: Error) -> Self { + Self::invalid_argument(e.to_string()) + } +} + +pub(crate) trait ConvertExt { + fn convert(self) -> Result; +} + +impl ConvertExt for T +where + T: TryInto, + T::Error: std::fmt::Display, +{ + fn convert(self) -> Result { + self.try_into().map_err(Error::new) + } +} + +pub(crate) trait ConvertOptExt { + fn convert_opt(self) -> Result; +} + +impl ConvertOptExt> for Option +where + T: TryInto, + T::Error: std::fmt::Display, +{ + fn convert_opt(self) -> Result, Error> { + self.map(|x| x.convert()).transpose() + } +} + +pub(crate) trait RequiredExt { + fn required(self) -> Result; +} + +impl RequiredExt for Option { + fn required(self) -> Result { + self.ok_or_else(|| Error::new("required")) + } +} + +pub(crate) trait ContextExt { + fn ctx(self, path: &'static str) -> Result; +} + +impl ContextExt for Result { + fn ctx(self, path: &'static str) -> Self { + self.map_err(|e| e.ctx(path)) + } +} + +pub(crate) fn catalog_error_to_status(e: crate::interface::Error) -> tonic::Status { + use 
crate::interface::Error; + + match e { + Error::External { source } => tonic::Status::internal(source.to_string()), + Error::AlreadyExists { descr } => tonic::Status::already_exists(descr), + Error::LimitExceeded { descr } => tonic::Status::resource_exhausted(descr), + Error::NotFound { descr } => tonic::Status::not_found(descr), + } +} + +pub(crate) fn convert_status(status: tonic::Status) -> crate::interface::Error { + use crate::interface::Error; + + match status.code() { + tonic::Code::Internal => Error::External { + source: status.message().to_owned().into(), + }, + tonic::Code::AlreadyExists => Error::AlreadyExists { + descr: status.message().to_owned(), + }, + tonic::Code::ResourceExhausted => Error::LimitExceeded { + descr: status.message().to_owned(), + }, + tonic::Code::NotFound => Error::NotFound { + descr: status.message().to_owned(), + }, + _ => Error::External { + source: Box::new(status), + }, + } +} + +pub(crate) fn serialize_soft_deleted_rows(sdr: SoftDeletedRows) -> i32 { + let sdr = match sdr { + SoftDeletedRows::AllRows => proto::SoftDeletedRows::AllRows, + SoftDeletedRows::ExcludeDeleted => proto::SoftDeletedRows::ExcludeDeleted, + SoftDeletedRows::OnlyDeleted => proto::SoftDeletedRows::OnlyDeleted, + }; + + sdr.into() +} + +pub(crate) fn deserialize_soft_deleted_rows(sdr: i32) -> Result { + let sdr: proto::SoftDeletedRows = sdr.convert().ctx("soft deleted rows")?; + let sdr = match sdr { + proto::SoftDeletedRows::Unspecified => { + return Err(Error::new("unspecified soft deleted rows")); + } + proto::SoftDeletedRows::AllRows => SoftDeletedRows::AllRows, + proto::SoftDeletedRows::ExcludeDeleted => SoftDeletedRows::ExcludeDeleted, + proto::SoftDeletedRows::OnlyDeleted => SoftDeletedRows::OnlyDeleted, + }; + Ok(sdr) +} + +pub(crate) fn serialize_namespace(ns: Namespace) -> proto::Namespace { + proto::Namespace { + id: ns.id.get(), + name: ns.name, + retention_period_ns: ns.retention_period_ns, + max_tables: ns.max_tables.get_i32(), + 
max_columns_per_table: ns.max_columns_per_table.get_i32(), + deleted_at: ns.deleted_at.map(|ts| ts.get()), + partition_template: ns.partition_template.as_proto().cloned(), + } +} + +pub(crate) fn deserialize_namespace(ns: proto::Namespace) -> Result { + Ok(Namespace { + id: NamespaceId::new(ns.id), + name: ns.name, + retention_period_ns: ns.retention_period_ns, + max_tables: ns.max_tables.convert().ctx("max_tables")?, + max_columns_per_table: ns + .max_columns_per_table + .convert() + .ctx("max_columns_per_table")?, + deleted_at: ns.deleted_at.map(Timestamp::new), + partition_template: ns + .partition_template + .convert_opt() + .ctx("partition_template")? + .unwrap_or_else(NamespacePartitionTemplateOverride::const_default), + }) +} + +pub(crate) fn serialize_table(t: Table) -> proto::Table { + proto::Table { + id: t.id.get(), + namespace_id: t.namespace_id.get(), + name: t.name, + partition_template: t.partition_template.as_proto().cloned(), + } +} + +pub(crate) fn deserialize_table(t: proto::Table) -> Result { + Ok(Table { + id: TableId::new(t.id), + namespace_id: NamespaceId::new(t.namespace_id), + name: t.name, + partition_template: t.partition_template.convert().ctx("partition_template")?, + }) +} + +pub(crate) fn serialize_column_type(t: ColumnType) -> i32 { + use generated_types::influxdata::iox::column_type::v1 as proto; + proto::ColumnType::from(t).into() +} + +pub(crate) fn deserialize_column_type(t: i32) -> Result { + use generated_types::influxdata::iox::column_type::v1 as proto; + let t: proto::ColumnType = t.convert()?; + t.convert() +} + +pub(crate) fn serialize_column(column: Column) -> proto::Column { + proto::Column { + id: column.id.get(), + table_id: column.table_id.get(), + name: column.name, + column_type: serialize_column_type(column.column_type), + } +} + +pub(crate) fn deserialize_column(column: proto::Column) -> Result { + Ok(Column { + id: ColumnId::new(column.id), + table_id: TableId::new(column.table_id), + name: column.name, + 
column_type: deserialize_column_type(column.column_type)?, + }) +} + +pub(crate) fn serialize_sort_key_ids(sort_key_ids: &SortKeyIds) -> proto::SortKeyIds { + proto::SortKeyIds { + column_ids: sort_key_ids.iter().map(|c_id| c_id.get()).collect(), + } +} + +pub(crate) fn deserialize_sort_key_ids(sort_key_ids: proto::SortKeyIds) -> SortKeyIds { + SortKeyIds::new(sort_key_ids.column_ids.into_iter().map(ColumnId::new)) +} + +pub(crate) fn serialize_partition(partition: Partition) -> proto::Partition { + let empty_sk = SortKeyIds::new(std::iter::empty()); + + proto::Partition { + id: partition.id.get(), + hash_id: partition + .hash_id() + .map(|id| id.as_bytes().to_vec()) + .unwrap_or_default(), + partition_key: partition.partition_key.inner().to_owned(), + table_id: partition.table_id.get(), + sort_key_ids: Some(serialize_sort_key_ids( + partition.sort_key_ids().unwrap_or(&empty_sk), + )), + new_file_at: partition.new_file_at.map(|ts| ts.get()), + } +} + +pub(crate) fn deserialize_partition(partition: proto::Partition) -> Result { + Ok(Partition::new_catalog_only( + PartitionId::new(partition.id), + (!partition.hash_id.is_empty()) + .then_some(partition.hash_id.as_slice()) + .convert_opt() + .ctx("hash_id")?, + TableId::new(partition.table_id), + partition.partition_key.into(), + deserialize_sort_key_ids(partition.sort_key_ids.required().ctx("sort_key_ids")?), + partition.new_file_at.map(Timestamp::new), + )) +} + +pub(crate) fn serialize_skipped_compaction(sc: SkippedCompaction) -> proto::SkippedCompaction { + proto::SkippedCompaction { + partition_id: sc.partition_id.get(), + reason: sc.reason, + skipped_at: sc.skipped_at.get(), + estimated_bytes: sc.estimated_bytes, + limit_bytes: sc.limit_bytes, + num_files: sc.num_files, + limit_num_files: sc.limit_num_files, + limit_num_files_first_in_partition: sc.limit_num_files_first_in_partition, + } +} + +pub(crate) fn deserialize_skipped_compaction(sc: proto::SkippedCompaction) -> SkippedCompaction { + SkippedCompaction { + 
partition_id: PartitionId::new(sc.partition_id), + reason: sc.reason, + skipped_at: Timestamp::new(sc.skipped_at), + estimated_bytes: sc.estimated_bytes, + limit_bytes: sc.limit_bytes, + num_files: sc.num_files, + limit_num_files: sc.limit_num_files, + limit_num_files_first_in_partition: sc.limit_num_files_first_in_partition, + } +} + +pub(crate) fn serialize_object_store_id(id: ObjectStoreId) -> proto::ObjectStoreId { + let (high64, low64) = id.get_uuid().as_u64_pair(); + proto::ObjectStoreId { high64, low64 } +} + +pub(crate) fn deserialize_object_store_id(id: proto::ObjectStoreId) -> ObjectStoreId { + ObjectStoreId::from_uuid(Uuid::from_u64_pair(id.high64, id.low64)) +} + +pub(crate) fn serialize_column_set(set: &ColumnSet) -> proto::ColumnSet { + proto::ColumnSet { + column_ids: set.iter().map(|id| id.get()).collect(), + } +} + +pub(crate) fn deserialize_column_set(set: proto::ColumnSet) -> ColumnSet { + ColumnSet::new(set.column_ids.into_iter().map(ColumnId::new)) +} + +pub(crate) fn serialize_parquet_file_params( + params: &ParquetFileParams, +) -> proto::ParquetFileParams { + proto::ParquetFileParams { + namespace_id: params.namespace_id.get(), + table_id: params.table_id.get(), + partition_id: params.partition_id.get(), + partition_hash_id: params + .partition_hash_id + .as_ref() + .map(|id| id.as_bytes().to_vec()), + object_store_id: Some(serialize_object_store_id(params.object_store_id)), + min_time: params.min_time.get(), + max_time: params.max_time.get(), + file_size_bytes: params.file_size_bytes, + row_count: params.row_count, + compaction_level: params.compaction_level as i32, + created_at: params.created_at.get(), + column_set: Some(serialize_column_set(¶ms.column_set)), + max_l0_created_at: params.max_l0_created_at.get(), + } +} + +pub(crate) fn deserialize_parquet_file_params( + params: proto::ParquetFileParams, +) -> Result { + Ok(ParquetFileParams { + namespace_id: NamespaceId::new(params.namespace_id), + table_id: TableId::new(params.table_id), 
+ partition_id: PartitionId::new(params.partition_id), + partition_hash_id: params + .partition_hash_id + .as_deref() + .convert_opt() + .ctx("partition_hash_id")?, + object_store_id: deserialize_object_store_id( + params.object_store_id.required().ctx("object_store_id")?, + ), + min_time: Timestamp::new(params.min_time), + max_time: Timestamp::new(params.max_time), + file_size_bytes: params.file_size_bytes, + row_count: params.row_count, + compaction_level: params.compaction_level.convert().ctx("compaction_level")?, + created_at: Timestamp::new(params.created_at), + column_set: deserialize_column_set(params.column_set.required().ctx("column_set")?), + max_l0_created_at: Timestamp::new(params.max_l0_created_at), + }) +} + +pub(crate) fn serialize_parquet_file(file: ParquetFile) -> proto::ParquetFile { + let partition_hash_id = file + .partition_hash_id + .map(|x| x.as_bytes().to_vec()) + .unwrap_or_default(); + + proto::ParquetFile { + id: file.id.get(), + namespace_id: file.namespace_id.get(), + table_id: file.table_id.get(), + partition_id: file.partition_id.get(), + partition_hash_id, + object_store_id: Some(serialize_object_store_id(file.object_store_id)), + min_time: file.min_time.get(), + max_time: file.max_time.get(), + to_delete: file.to_delete.map(|ts| ts.get()), + file_size_bytes: file.file_size_bytes, + row_count: file.row_count, + compaction_level: file.compaction_level as i32, + created_at: file.created_at.get(), + column_set: Some(serialize_column_set(&file.column_set)), + max_l0_created_at: file.max_l0_created_at.get(), + } +} + +pub(crate) fn deserialize_parquet_file(file: proto::ParquetFile) -> Result { + let partition_hash_id = match file.partition_hash_id.as_slice() { + b"" => None, + s => Some(s.convert().ctx("partition_hash_id")?), + }; + + Ok(ParquetFile { + id: ParquetFileId::new(file.id), + namespace_id: NamespaceId::new(file.namespace_id), + table_id: TableId::new(file.table_id), + partition_id: PartitionId::new(file.partition_id), + 
partition_hash_id, + object_store_id: deserialize_object_store_id( + file.object_store_id.required().ctx("object_store_id")?, + ), + min_time: Timestamp::new(file.min_time), + max_time: Timestamp::new(file.max_time), + to_delete: file.to_delete.map(Timestamp::new), + file_size_bytes: file.file_size_bytes, + row_count: file.row_count, + compaction_level: file.compaction_level.convert().ctx("compaction_level")?, + created_at: Timestamp::new(file.created_at), + column_set: deserialize_column_set(file.column_set.required().ctx("column_set")?), + max_l0_created_at: Timestamp::new(file.max_l0_created_at), + }) +} + +#[cfg(test)] +mod tests { + use data_types::{ + partition_template::TablePartitionTemplateOverride, CompactionLevel, PartitionHashId, + PartitionKey, + }; + + use super::*; + + #[test] + fn test_column_type_roundtrip() { + assert_column_type_roundtrip(ColumnType::Bool); + assert_column_type_roundtrip(ColumnType::I64); + assert_column_type_roundtrip(ColumnType::U64); + assert_column_type_roundtrip(ColumnType::F64); + assert_column_type_roundtrip(ColumnType::String); + assert_column_type_roundtrip(ColumnType::Tag); + assert_column_type_roundtrip(ColumnType::Time); + } + + #[track_caller] + fn assert_column_type_roundtrip(t: ColumnType) { + let protobuf = serialize_column_type(t); + let t2 = deserialize_column_type(protobuf).unwrap(); + assert_eq!(t, t2); + } + + #[test] + fn test_error_roundtrip() { + use crate::interface::Error; + + assert_error_roundtrip(Error::AlreadyExists { + descr: "foo".to_owned(), + }); + assert_error_roundtrip(Error::External { + source: "foo".to_owned().into(), + }); + assert_error_roundtrip(Error::LimitExceeded { + descr: "foo".to_owned(), + }); + assert_error_roundtrip(Error::NotFound { + descr: "foo".to_owned(), + }); + } + + #[track_caller] + fn assert_error_roundtrip(e: crate::interface::Error) { + let msg_orig = e.to_string(); + + let status = catalog_error_to_status(e); + let e = convert_status(status); + let msg = 
e.to_string(); + assert_eq!(msg, msg_orig); + } + + #[test] + fn test_soft_deleted_rows_roundtrip() { + assert_soft_deleted_rows_roundtrip(SoftDeletedRows::AllRows); + assert_soft_deleted_rows_roundtrip(SoftDeletedRows::ExcludeDeleted); + assert_soft_deleted_rows_roundtrip(SoftDeletedRows::OnlyDeleted); + } + + #[track_caller] + fn assert_soft_deleted_rows_roundtrip(sdr: SoftDeletedRows) { + let protobuf = serialize_soft_deleted_rows(sdr); + let sdr2 = deserialize_soft_deleted_rows(protobuf).unwrap(); + assert_eq!(sdr, sdr2); + } + + #[test] + fn test_namespace_roundtrip() { + use generated_types::influxdata::iox::partition_template::v1 as proto; + + let ns = Namespace { + id: NamespaceId::new(1), + name: "ns".to_owned(), + retention_period_ns: Some(2), + max_tables: 3.try_into().unwrap(), + max_columns_per_table: 4.try_into().unwrap(), + deleted_at: Some(Timestamp::new(5)), + partition_template: NamespacePartitionTemplateOverride::try_from( + proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat("year-%Y".into())), + }], + }, + ) + .unwrap(), + }; + let protobuf = serialize_namespace(ns.clone()); + let ns2 = deserialize_namespace(protobuf).unwrap(); + assert_eq!(ns, ns2); + } + + #[test] + fn test_table_roundtrip() { + use generated_types::influxdata::iox::partition_template::v1 as proto; + + let table = Table { + id: TableId::new(1), + namespace_id: NamespaceId::new(2), + name: "table".to_owned(), + partition_template: TablePartitionTemplateOverride::try_new( + Some(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat("year-%Y".into())), + }], + }), + &NamespacePartitionTemplateOverride::const_default(), + ) + .unwrap(), + }; + let protobuf = serialize_table(table.clone()); + let table2 = deserialize_table(protobuf).unwrap(); + assert_eq!(table, table2); + } + + #[test] + fn test_column_roundtrip() { + let column = Column { + id: 
ColumnId::new(1), + table_id: TableId::new(2), + name: "col".to_owned(), + column_type: ColumnType::F64, + }; + let protobuf = serialize_column(column.clone()); + let column2 = deserialize_column(protobuf).unwrap(); + assert_eq!(column, column2); + } + + #[test] + fn test_sort_key_ids_roundtrip() { + assert_sort_key_ids_roundtrip(SortKeyIds::new(std::iter::empty())); + assert_sort_key_ids_roundtrip(SortKeyIds::new([ColumnId::new(1)])); + assert_sort_key_ids_roundtrip(SortKeyIds::new([ + ColumnId::new(1), + ColumnId::new(5), + ColumnId::new(20), + ])); + } + + #[track_caller] + fn assert_sort_key_ids_roundtrip(sort_key_ids: SortKeyIds) { + let protobuf = serialize_sort_key_ids(&sort_key_ids); + let sort_key_ids2 = deserialize_sort_key_ids(protobuf); + assert_eq!(sort_key_ids, sort_key_ids2); + } + + #[test] + fn test_partition_roundtrip() { + let table_id = TableId::new(1); + let partition_key = PartitionKey::from("key"); + let hash_id = PartitionHashId::new(table_id, &partition_key); + + assert_partition_roundtrip(Partition::new_catalog_only( + PartitionId::new(2), + Some(hash_id.clone()), + table_id, + partition_key.clone(), + SortKeyIds::new([ColumnId::new(3), ColumnId::new(4)]), + Some(Timestamp::new(5)), + )); + assert_partition_roundtrip(Partition::new_catalog_only( + PartitionId::new(2), + Some(hash_id), + table_id, + partition_key, + SortKeyIds::new(std::iter::empty()), + Some(Timestamp::new(5)), + )); + } + + #[track_caller] + fn assert_partition_roundtrip(partition: Partition) { + let protobuf = serialize_partition(partition.clone()); + let partition2 = deserialize_partition(protobuf).unwrap(); + assert_eq!(partition, partition2); + } + + #[test] + fn test_skipped_compaction_roundtrip() { + let sc = SkippedCompaction { + partition_id: PartitionId::new(1), + reason: "foo".to_owned(), + skipped_at: Timestamp::new(2), + estimated_bytes: 3, + limit_bytes: 4, + num_files: 5, + limit_num_files: 6, + limit_num_files_first_in_partition: 7, + }; + let protobuf = 
serialize_skipped_compaction(sc.clone()); + let sc2 = deserialize_skipped_compaction(protobuf); + assert_eq!(sc, sc2); + } + + #[test] + fn test_object_store_id_roundtrip() { + assert_object_store_id_roundtrip(ObjectStoreId::from_uuid(Uuid::nil())); + assert_object_store_id_roundtrip(ObjectStoreId::from_uuid(Uuid::from_u128(0))); + assert_object_store_id_roundtrip(ObjectStoreId::from_uuid(Uuid::from_u128(u128::MAX))); + assert_object_store_id_roundtrip(ObjectStoreId::from_uuid(Uuid::from_u128(1))); + assert_object_store_id_roundtrip(ObjectStoreId::from_uuid(Uuid::from_u128(u128::MAX - 1))); + } + + #[track_caller] + fn assert_object_store_id_roundtrip(id: ObjectStoreId) { + let protobuf = serialize_object_store_id(id); + let id2 = deserialize_object_store_id(protobuf); + assert_eq!(id, id2); + } + + #[test] + fn test_column_set_roundtrip() { + assert_column_set_roundtrip(ColumnSet::new([])); + assert_column_set_roundtrip(ColumnSet::new([ColumnId::new(1)])); + assert_column_set_roundtrip(ColumnSet::new([ColumnId::new(1), ColumnId::new(10)])); + assert_column_set_roundtrip(ColumnSet::new([ + ColumnId::new(3), + ColumnId::new(4), + ColumnId::new(10), + ])); + } + + #[track_caller] + fn assert_column_set_roundtrip(set: ColumnSet) { + let protobuf = serialize_column_set(&set); + let set2 = deserialize_column_set(protobuf); + assert_eq!(set, set2); + } + + #[test] + fn test_parquet_file_params_roundtrip() { + let params = ParquetFileParams { + namespace_id: NamespaceId::new(1), + table_id: TableId::new(2), + partition_id: PartitionId::new(3), + partition_hash_id: Some(PartitionHashId::arbitrary_for_testing()), + object_store_id: ObjectStoreId::from_uuid(Uuid::from_u128(1337)), + min_time: Timestamp::new(4), + max_time: Timestamp::new(5), + file_size_bytes: 6, + row_count: 7, + compaction_level: CompactionLevel::Final, + created_at: Timestamp::new(8), + column_set: ColumnSet::new([ColumnId::new(9), ColumnId::new(10)]), + max_l0_created_at: Timestamp::new(11), + }; + let 
protobuf = serialize_parquet_file_params(¶ms); + let params2 = deserialize_parquet_file_params(protobuf).unwrap(); + assert_eq!(params, params2); + } + + #[test] + fn test_parquet_file_roundtrip() { + let file = ParquetFile { + id: ParquetFileId::new(12), + namespace_id: NamespaceId::new(1), + table_id: TableId::new(2), + partition_id: PartitionId::new(3), + partition_hash_id: Some(PartitionHashId::arbitrary_for_testing()), + object_store_id: ObjectStoreId::from_uuid(Uuid::from_u128(1337)), + min_time: Timestamp::new(4), + max_time: Timestamp::new(5), + to_delete: Some(Timestamp::new(13)), + file_size_bytes: 6, + row_count: 7, + compaction_level: CompactionLevel::Final, + created_at: Timestamp::new(8), + column_set: ColumnSet::new([ColumnId::new(9), ColumnId::new(10)]), + max_l0_created_at: Timestamp::new(11), + }; + let protobuf = serialize_parquet_file(file.clone()); + let file2 = deserialize_parquet_file(protobuf).unwrap(); + assert_eq!(file, file2); + } +} diff --git a/iox_catalog/src/grpc/server.rs b/iox_catalog/src/grpc/server.rs new file mode 100644 index 0000000..2105457 --- /dev/null +++ b/iox_catalog/src/grpc/server.rs @@ -0,0 +1,1032 @@ +//! gRPC server implementation. 
+ +use std::{pin::Pin, sync::Arc}; + +use crate::{ + grpc::serialization::{ + catalog_error_to_status, deserialize_column_type, deserialize_object_store_id, + deserialize_parquet_file_params, deserialize_soft_deleted_rows, deserialize_sort_key_ids, + serialize_column, serialize_namespace, serialize_object_store_id, serialize_parquet_file, + serialize_partition, serialize_skipped_compaction, serialize_sort_key_ids, serialize_table, + ContextExt, ConvertExt, ConvertOptExt, RequiredExt, + }, + interface::{CasFailure, Catalog}, +}; +use async_trait::async_trait; +use data_types::{ + NamespaceId, NamespaceServiceProtectionLimitsOverride, PartitionId, PartitionKey, TableId, + Timestamp, +}; +use futures::{Stream, StreamExt, TryStreamExt}; +use generated_types::influxdata::iox::catalog::v2 as proto; +use generated_types::influxdata::iox::catalog::v2::{TableSnapshotRequest, TableSnapshotResponse}; +use tonic::{Request, Response, Status}; + +type TonicStream = Pin> + Send + 'static>>; + +/// gRPC server. +#[derive(Debug)] +pub struct GrpcCatalogServer { + catalog: Arc, +} + +impl GrpcCatalogServer { + /// Create a new [`GrpcCatalogServer`]. + pub fn new(catalog: Arc) -> Self { + Self { catalog } + } + + /// Get service for integration w/ tonic. 
+ pub fn service(&self) -> proto::catalog_service_server::CatalogServiceServer { + let this = Self { + catalog: Arc::clone(&self.catalog), + }; + proto::catalog_service_server::CatalogServiceServer::new(this) + } +} + +#[async_trait] +impl proto::catalog_service_server::CatalogService for GrpcCatalogServer { + type NamespaceListStream = TonicStream; + + type TableListByNamespaceIdStream = TonicStream; + type TableListStream = TonicStream; + + type ColumnCreateOrGetManyUncheckedStream = + TonicStream; + type ColumnListByNamespaceIdStream = TonicStream; + type ColumnListByTableIdStream = TonicStream; + type ColumnListStream = TonicStream; + + type PartitionGetByIdBatchStream = TonicStream; + type PartitionListByTableIdStream = TonicStream; + type PartitionListIdsStream = TonicStream; + type PartitionGetInSkippedCompactionsStream = + TonicStream; + type PartitionListSkippedCompactionsStream = + TonicStream; + type PartitionMostRecentNStream = TonicStream; + type PartitionNewFileBetweenStream = TonicStream; + type PartitionListOldStyleStream = TonicStream; + + type ParquetFileFlagForDeleteByRetentionStream = + TonicStream; + type ParquetFileDeleteOldIdsOnlyStream = + TonicStream; + type ParquetFileListByPartitionNotToDeleteBatchStream = + TonicStream; + type ParquetFileExistsByObjectStoreIdBatchStream = + TonicStream; + + async fn namespace_create( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let ns = self + .catalog + .repositories() + .namespaces() + .create( + &req.name.convert().ctx("name")?, + req.partition_template + .convert_opt() + .ctx("partition_template")?, + req.retention_period_ns, + req.service_protection_limits + .map(|l| { + let l = NamespaceServiceProtectionLimitsOverride { + max_tables: l.max_tables.convert_opt().ctx("max_tables")?, + max_columns_per_table: l + .max_columns_per_table + .convert_opt() + .ctx("max_columns_per_table")?, + }; + Ok(l) as Result<_, tonic::Status> + }) + .transpose()?, + 
) + .await + .map_err(catalog_error_to_status)?; + + let ns = serialize_namespace(ns); + + Ok(Response::new(proto::NamespaceCreateResponse { + namespace: Some(ns), + })) + } + + async fn namespace_update_retention_period( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let ns = self + .catalog + .repositories() + .namespaces() + .update_retention_period(&req.name, req.retention_period_ns) + .await + .map_err(catalog_error_to_status)?; + + let ns = serialize_namespace(ns); + + Ok(Response::new( + proto::NamespaceUpdateRetentionPeriodResponse { + namespace: Some(ns), + }, + )) + } + + async fn namespace_list( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + let deleted = deserialize_soft_deleted_rows(req.deleted)?; + + let ns_list = self + .catalog + .repositories() + .namespaces() + .list(deleted) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(ns_list.into_iter().map(|ns| { + let ns = serialize_namespace(ns); + + Ok(proto::NamespaceListResponse { + namespace: Some(ns), + }) + })) + .boxed(), + )) + } + + async fn namespace_get_by_id( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + let deleted = deserialize_soft_deleted_rows(req.deleted)?; + + let maybe_ns = self + .catalog + .repositories() + .namespaces() + .get_by_id(NamespaceId::new(req.id), deleted) + .await + .map_err(catalog_error_to_status)?; + + let maybe_ns = maybe_ns.map(serialize_namespace); + + Ok(Response::new(proto::NamespaceGetByIdResponse { + namespace: maybe_ns, + })) + } + + async fn namespace_get_by_name( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + let deleted = deserialize_soft_deleted_rows(req.deleted)?; + + let maybe_ns = self + .catalog + .repositories() + .namespaces() + .get_by_name(&req.name, deleted) + .await + .map_err(catalog_error_to_status)?; 
+ + let maybe_ns = maybe_ns.map(serialize_namespace); + + Ok(Response::new(proto::NamespaceGetByNameResponse { + namespace: maybe_ns, + })) + } + + async fn namespace_soft_delete( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + self.catalog + .repositories() + .namespaces() + .soft_delete(&req.name) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new(proto::NamespaceSoftDeleteResponse {})) + } + + async fn namespace_update_table_limit( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let ns = self + .catalog + .repositories() + .namespaces() + .update_table_limit(&req.name, req.new_max.convert().ctx("new_max")?) + .await + .map_err(catalog_error_to_status)?; + + let ns = serialize_namespace(ns); + + Ok(Response::new(proto::NamespaceUpdateTableLimitResponse { + namespace: Some(ns), + })) + } + + async fn namespace_update_column_limit( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let ns = self + .catalog + .repositories() + .namespaces() + .update_column_limit(&req.name, req.new_max.convert().ctx("new_max")?) 
+ .await + .map_err(catalog_error_to_status)?; + + let ns = serialize_namespace(ns); + + Ok(Response::new(proto::NamespaceUpdateColumnLimitResponse { + namespace: Some(ns), + })) + } + + async fn table_create( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let table = self + .catalog + .repositories() + .tables() + .create( + &req.name, + req.partition_template.convert().ctx("partition_template")?, + NamespaceId::new(req.namespace_id), + ) + .await + .map_err(catalog_error_to_status)?; + + let table = serialize_table(table); + + Ok(Response::new(proto::TableCreateResponse { + table: Some(table), + })) + } + + async fn table_get_by_id( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let maybe_table = self + .catalog + .repositories() + .tables() + .get_by_id(TableId::new(req.id)) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new(proto::TableGetByIdResponse { + table: maybe_table.map(serialize_table), + })) + } + + async fn table_get_by_namespace_and_name( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let maybe_table = self + .catalog + .repositories() + .tables() + .get_by_namespace_and_name(NamespaceId::new(req.namespace_id), &req.name) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new(proto::TableGetByNamespaceAndNameResponse { + table: maybe_table.map(serialize_table), + })) + } + + async fn table_list_by_namespace_id( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let table_list = self + .catalog + .repositories() + .tables() + .list_by_namespace_id(NamespaceId::new(req.namespace_id)) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(table_list.into_iter().map(|table| { + let table = serialize_table(table); + Ok(proto::TableListByNamespaceIdResponse { table: Some(table) }) + })) + 
.boxed(), + )) + } + + async fn table_list( + &self, + _request: Request, + ) -> Result, tonic::Status> { + let table_list = self + .catalog + .repositories() + .tables() + .list() + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(table_list.into_iter().map(|table| { + let table = serialize_table(table); + Ok(proto::TableListResponse { table: Some(table) }) + })) + .boxed(), + )) + } + + async fn table_snapshot( + &self, + request: Request, + ) -> Result, Status> { + let req = request.into_inner(); + let snapshot = self + .catalog + .repositories() + .tables() + .snapshot(TableId::new(req.table_id)) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new(TableSnapshotResponse { + generation: snapshot.generation(), + table: Some(snapshot.into()), + })) + } + + async fn column_create_or_get( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + let column_type = deserialize_column_type(req.column_type)?; + + let column = self + .catalog + .repositories() + .columns() + .create_or_get(&req.name, TableId::new(req.table_id), column_type) + .await + .map_err(catalog_error_to_status)?; + + let column = serialize_column(column); + + Ok(Response::new(proto::ColumnCreateOrGetResponse { + column: Some(column), + })) + } + + async fn column_create_or_get_many_unchecked( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + let columns = req + .columns + .iter() + .map(|(name, t)| { + let t = deserialize_column_type(*t)?; + Ok((name.as_str(), t)) + }) + .collect::>()?; + + let column_list = self + .catalog + .repositories() + .columns() + .create_or_get_many_unchecked(TableId::new(req.table_id), columns) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(column_list.into_iter().map(|column| { + let column = serialize_column(column); + Ok(proto::ColumnCreateOrGetManyUncheckedResponse { + column: 
Some(column), + }) + })) + .boxed(), + )) + } + + async fn column_list_by_namespace_id( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let column_list = self + .catalog + .repositories() + .columns() + .list_by_namespace_id(NamespaceId::new(req.namespace_id)) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(column_list.into_iter().map(|column| { + let column = serialize_column(column); + Ok(proto::ColumnListByNamespaceIdResponse { + column: Some(column), + }) + })) + .boxed(), + )) + } + + async fn column_list_by_table_id( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let column_list = self + .catalog + .repositories() + .columns() + .list_by_table_id(TableId::new(req.table_id)) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(column_list.into_iter().map(|column| { + let column = serialize_column(column); + Ok(proto::ColumnListByTableIdResponse { + column: Some(column), + }) + })) + .boxed(), + )) + } + + async fn column_list( + &self, + _request: Request, + ) -> Result, tonic::Status> { + let column_list = self + .catalog + .repositories() + .columns() + .list() + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(column_list.into_iter().map(|column| { + let column = serialize_column(column); + Ok(proto::ColumnListResponse { + column: Some(column), + }) + })) + .boxed(), + )) + } + + async fn partition_create_or_get( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let partition = self + .catalog + .repositories() + .partitions() + .create_or_get(PartitionKey::from(req.key), TableId::new(req.table_id)) + .await + .map_err(catalog_error_to_status)?; + + let partition = serialize_partition(partition); + + Ok(Response::new(proto::PartitionCreateOrGetResponse { + partition: Some(partition), 
+ })) + } + + async fn partition_get_by_id_batch( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + let partition_ids = req + .partition_ids + .into_iter() + .map(PartitionId::new) + .collect::>(); + + let partition_list = self + .catalog + .repositories() + .partitions() + .get_by_id_batch(&partition_ids) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(partition_list.into_iter().map(|partition| { + let partition = serialize_partition(partition); + Ok(proto::PartitionGetByIdBatchResponse { + partition: Some(partition), + }) + })) + .boxed(), + )) + } + + async fn partition_list_by_table_id( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let partition_list = self + .catalog + .repositories() + .partitions() + .list_by_table_id(TableId::new(req.table_id)) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(partition_list.into_iter().map(|partition| { + let partition = serialize_partition(partition); + Ok(proto::PartitionListByTableIdResponse { + partition: Some(partition), + }) + })) + .boxed(), + )) + } + + async fn partition_list_ids( + &self, + _request: Request, + ) -> Result, tonic::Status> { + let id_list = self + .catalog + .repositories() + .partitions() + .list_ids() + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(id_list.into_iter().map(|id| { + Ok(proto::PartitionListIdsResponse { + partition_id: id.get(), + }) + })) + .boxed(), + )) + } + + async fn partition_cas_sort_key( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let res = self + .catalog + .repositories() + .partitions() + .cas_sort_key( + PartitionId::new(req.partition_id), + req.old_sort_key_ids.map(deserialize_sort_key_ids).as_ref(), + &deserialize_sort_key_ids(req.new_sort_key_ids.required().ctx("new_sort_key_ids")?), 
+ ) + .await; + + match res { + Ok(partition) => Ok(Response::new(proto::PartitionCasSortKeyResponse { + res: Some(proto::partition_cas_sort_key_response::Res::Partition( + serialize_partition(partition), + )), + })), + Err(CasFailure::ValueMismatch(sort_key_ids)) => { + Ok(Response::new(proto::PartitionCasSortKeyResponse { + res: Some(proto::partition_cas_sort_key_response::Res::CurrentSortKey( + serialize_sort_key_ids(&sort_key_ids), + )), + })) + } + Err(CasFailure::QueryError(e)) => Err(catalog_error_to_status(e)), + } + } + + async fn partition_record_skipped_compaction( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + self.catalog + .repositories() + .partitions() + .record_skipped_compaction( + PartitionId::new(req.partition_id), + &req.reason, + req.num_files as usize, + req.limit_num_files as usize, + req.limit_num_files_first_in_partition as usize, + req.estimated_bytes, + req.limit_bytes, + ) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + proto::PartitionRecordSkippedCompactionResponse {}, + )) + } + + async fn partition_get_in_skipped_compactions( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + let partition_ids = req + .partition_ids + .into_iter() + .map(PartitionId::new) + .collect::>(); + + let skipped_compaction_list = self + .catalog + .repositories() + .partitions() + .get_in_skipped_compactions(&partition_ids) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(skipped_compaction_list.into_iter().map(|sc| { + let sc = serialize_skipped_compaction(sc); + Ok(proto::PartitionGetInSkippedCompactionsResponse { + skipped_compaction: Some(sc), + }) + })) + .boxed(), + )) + } + + async fn partition_list_skipped_compactions( + &self, + _request: Request, + ) -> Result, tonic::Status> { + let skipped_compaction_list = self + .catalog + .repositories() + .partitions() + 
.list_skipped_compactions() + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(skipped_compaction_list.into_iter().map(|sc| { + let sc = serialize_skipped_compaction(sc); + Ok(proto::PartitionListSkippedCompactionsResponse { + skipped_compaction: Some(sc), + }) + })) + .boxed(), + )) + } + + async fn partition_delete_skipped_compactions( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let maybe_skipped_compaction = self + .catalog + .repositories() + .partitions() + .delete_skipped_compactions(PartitionId::new(req.partition_id)) + .await + .map_err(catalog_error_to_status)?; + + let maybe_skipped_compaction = maybe_skipped_compaction.map(serialize_skipped_compaction); + + Ok(Response::new( + proto::PartitionDeleteSkippedCompactionsResponse { + skipped_compaction: maybe_skipped_compaction, + }, + )) + } + + async fn partition_most_recent_n( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let partition_list = self + .catalog + .repositories() + .partitions() + .most_recent_n(req.n as usize) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(partition_list.into_iter().map(|partition| { + let partition = serialize_partition(partition); + Ok(proto::PartitionMostRecentNResponse { + partition: Some(partition), + }) + })) + .boxed(), + )) + } + + async fn partition_new_file_between( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let id_list = self + .catalog + .repositories() + .partitions() + .partitions_new_file_between( + Timestamp::new(req.minimum_time), + req.maximum_time.map(Timestamp::new), + ) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(id_list.into_iter().map(|id| { + Ok(proto::PartitionNewFileBetweenResponse { + partition_id: id.get(), + }) + })) + .boxed(), + )) + } + + async fn 
partition_list_old_style( + &self, + _request: Request, + ) -> Result, tonic::Status> { + let partition_list = self + .catalog + .repositories() + .partitions() + .list_old_style() + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(partition_list.into_iter().map(|partition| { + let partition = serialize_partition(partition); + Ok(proto::PartitionListOldStyleResponse { + partition: Some(partition), + }) + })) + .boxed(), + )) + } + + async fn partition_snapshot( + &self, + request: Request, + ) -> Result, Status> { + let req = request.into_inner(); + let snapshot = self + .catalog + .repositories() + .partitions() + .snapshot(PartitionId::new(req.partition_id)) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new(proto::PartitionSnapshotResponse { + generation: snapshot.generation(), + partition: Some(snapshot.into()), + })) + } + + async fn parquet_file_flag_for_delete_by_retention( + &self, + _request: Request, + ) -> Result, tonic::Status> { + let id_list = self + .catalog + .repositories() + .parquet_files() + .flag_for_delete_by_retention() + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(id_list.into_iter().map(|(p_id, os_id)| { + let object_store_id = serialize_object_store_id(os_id); + Ok(proto::ParquetFileFlagForDeleteByRetentionResponse { + partition_id: p_id.get(), + object_store_id: Some(object_store_id), + }) + })) + .boxed(), + )) + } + + async fn parquet_file_delete_old_ids_only( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let id_list = self + .catalog + .repositories() + .parquet_files() + .delete_old_ids_only(Timestamp::new(req.older_than)) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(id_list.into_iter().map(|id| { + let object_store_id = serialize_object_store_id(id); + Ok(proto::ParquetFileDeleteOldIdsOnlyResponse { + object_store_id: 
Some(object_store_id), + }) + })) + .boxed(), + )) + } + + async fn parquet_file_list_by_partition_not_to_delete_batch( + &self, + request: Request, + ) -> Result, tonic::Status> + { + let req = request.into_inner(); + let partition_ids = req + .partition_ids + .into_iter() + .map(PartitionId::new) + .collect::>(); + + let file_list = self + .catalog + .repositories() + .parquet_files() + .list_by_partition_not_to_delete_batch(partition_ids) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(file_list.into_iter().map(|file| { + let file = serialize_parquet_file(file); + Ok(proto::ParquetFileListByPartitionNotToDeleteBatchResponse { + parquet_file: Some(file), + }) + })) + .boxed(), + )) + } + + async fn parquet_file_get_by_object_store_id( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + + let maybe_file = self + .catalog + .repositories() + .parquet_files() + .get_by_object_store_id(deserialize_object_store_id( + req.object_store_id.required().ctx("object_store_id")?, + )) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + proto::ParquetFileGetByObjectStoreIdResponse { + parquet_file: maybe_file.map(serialize_parquet_file), + }, + )) + } + + async fn parquet_file_exists_by_object_store_id_batch( + &self, + request: Request>, + ) -> Result, tonic::Status> { + let object_store_ids = request + .into_inner() + .map_err(|e| tonic::Status::invalid_argument(e.to_string())) + .and_then(|req| async move { + Ok(deserialize_object_store_id( + req.object_store_id.required().ctx("object_store_id")?, + )) + }) + .try_collect::>() + .await?; + + let id_list = self + .catalog + .repositories() + .parquet_files() + .exists_by_object_store_id_batch(object_store_ids) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + futures::stream::iter(id_list.into_iter().map(|id| { + let object_store_id = serialize_object_store_id(id); + 
Ok(proto::ParquetFileExistsByObjectStoreIdBatchResponse { + object_store_id: Some(object_store_id), + }) + })) + .boxed(), + )) + } + + async fn parquet_file_create_upgrade_delete( + &self, + request: Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + let delete = req + .delete + .into_iter() + .map(deserialize_object_store_id) + .collect::>(); + let upgrade = req + .upgrade + .into_iter() + .map(deserialize_object_store_id) + .collect::>(); + let create = req + .create + .into_iter() + .map(deserialize_parquet_file_params) + .collect::, _>>()?; + + let id_list = self + .catalog + .repositories() + .parquet_files() + .create_upgrade_delete( + PartitionId::new(req.partition_id), + &delete, + &upgrade, + &create, + req.target_level.convert().ctx("target_level")?, + ) + .await + .map_err(catalog_error_to_status)?; + + Ok(Response::new( + proto::ParquetFileCreateUpgradeDeleteResponse { + created_parquet_file_ids: id_list.into_iter().map(|id| id.get()).collect(), + }, + )) + } +} diff --git a/iox_catalog/src/interface.rs b/iox_catalog/src/interface.rs new file mode 100644 index 0000000..dae33a2 --- /dev/null +++ b/iox_catalog/src/interface.rs @@ -0,0 +1,490 @@ +//! Traits and data types for the IOx Catalog API. 
+ +use async_trait::async_trait; +use data_types::snapshot::partition::PartitionSnapshot; +use data_types::snapshot::table::TableSnapshot; +use data_types::{ + partition_template::{NamespacePartitionTemplateOverride, TablePartitionTemplateOverride}, + Column, ColumnType, CompactionLevel, MaxColumnsPerTable, MaxTables, Namespace, NamespaceId, + NamespaceName, NamespaceServiceProtectionLimitsOverride, ObjectStoreId, ParquetFile, + ParquetFileId, ParquetFileParams, Partition, PartitionId, PartitionKey, SkippedCompaction, + SortKeyIds, Table, TableId, Timestamp, +}; +use iox_time::TimeProvider; +use snafu::Snafu; +use std::{ + collections::HashMap, + fmt::{Debug, Display}, + sync::Arc, +}; + +/// An error wrapper detailing the reason for a compare-and-swap failure. +#[derive(Debug)] +pub enum CasFailure { + /// The compare-and-swap failed because the current value differers from the + /// comparator. + /// + /// Contains the new current value. + ValueMismatch(T), + /// A query error occurred. 
+ QueryError(Error), +} + +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +#[snafu(visibility(pub(crate)))] +pub enum Error { + #[snafu(display("unhandled external error: {source}"))] + External { + source: Box, + }, + + #[snafu(display("already exists: {descr}"))] + AlreadyExists { descr: String }, + + #[snafu(display("limit exceeded: {descr}"))] + LimitExceeded { descr: String }, + + #[snafu(display("not found: {descr}"))] + NotFound { descr: String }, +} + +impl From for Error { + fn from(e: sqlx::Error) -> Self { + Self::External { + source: Box::new(e), + } + } +} + +impl From for Error { + fn from(e: sqlx::migrate::MigrateError) -> Self { + Self::from(sqlx::Error::from(e)) + } +} + +impl From for Error { + fn from(e: data_types::snapshot::partition::Error) -> Self { + Self::External { + source: Box::new(e), + } + } +} + +impl From for Error { + fn from(e: data_types::snapshot::table::Error) -> Self { + Self::External { + source: Box::new(e), + } + } +} + +impl From for Error { + fn from(e: catalog_cache::api::quorum::Error) -> Self { + Self::External { + source: Box::new(e), + } + } +} + +impl From for Error { + fn from(e: generated_types::prost::DecodeError) -> Self { + Self::External { + source: Box::new(e), + } + } +} + +/// A specialized `Error` for Catalog errors +pub type Result = std::result::Result; + +/// Specify how soft-deleted entities should affect query results. +/// +/// ```text +/// +/// ExcludeDeleted OnlyDeleted +/// +/// ┃ ┃ +/// .─────╋─────. .─────╋─────. +/// ,─' ┃ '─. ,─' ┃ '─. +/// ,' ● `,' ● `. +/// ,' ,' `. `. +/// ; ; : : +/// │ No deleted │ │ Only deleted │ +/// │ rows │ ● │ rows │ +/// : : ┃ ; ; +/// ╲ ╲ ┃ ╱ ╱ +/// `. `┃' ,' +/// `. ,'┃`. ,' +/// '─. ,─' ┃ '─. ,─' +/// `─────────' ┃ `─────────' +/// ┃ +/// +/// AllRows +/// +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SoftDeletedRows { + /// Return all rows. + AllRows, + + /// Return all rows, except soft deleted rows. 
+ ExcludeDeleted, + + /// Return only soft deleted rows. + OnlyDeleted, +} + +impl SoftDeletedRows { + pub(crate) fn as_sql_predicate(&self) -> &str { + match self { + Self::ExcludeDeleted => "deleted_at IS NULL", + Self::OnlyDeleted => "deleted_at IS NOT NULL", + Self::AllRows => "1=1", + } + } +} + +/// Methods for working with the catalog. +#[async_trait] +pub trait Catalog: Send + Sync + Debug + Display { + /// Setup catalog for usage and apply possible migrations. + async fn setup(&self) -> Result<(), Error>; + + /// Accesses the repositories without a transaction scope. + fn repositories(&self) -> Box; + + /// Gets metric registry associated with this catalog for testing purposes. + #[cfg(test)] + fn metrics(&self) -> Arc; + + /// Gets the time provider associated with this catalog. + fn time_provider(&self) -> Arc; +} + +/// Methods for working with the catalog's various repositories (collections of entities). +/// +/// # Repositories +/// +/// The methods (e.g. `create_*` or `get_by_*`) for handling entities (namespaces, partitions, +/// etc.) are grouped into *repositories* with one repository per entity. A repository can be +/// thought of a collection of a single kind of entity. Getting repositories from the transaction +/// is cheap. +/// +/// A repository might internally map to a wide range of different storage abstractions, ranging +/// from one or more SQL tables over key-value key spaces to simple in-memory vectors. The user +/// should and must not care how these are implemented. +pub trait RepoCollection: Send + Sync + Debug { + /// Repository for [namespaces](data_types::Namespace). + fn namespaces(&mut self) -> &mut dyn NamespaceRepo; + + /// Repository for [tables](data_types::Table). + fn tables(&mut self) -> &mut dyn TableRepo; + + /// Repository for [columns](data_types::Column). + fn columns(&mut self) -> &mut dyn ColumnRepo; + + /// Repository for [partitions](data_types::Partition). 
+ fn partitions(&mut self) -> &mut dyn PartitionRepo; + + /// Repository for [Parquet files](data_types::ParquetFile). + fn parquet_files(&mut self) -> &mut dyn ParquetFileRepo; +} + +/// Functions for working with namespaces in the catalog +#[async_trait] +pub trait NamespaceRepo: Send + Sync { + /// Creates the namespace in the catalog. If one by the same name already exists, an + /// error is returned. + /// Specify `None` for `retention_period_ns` to get infinite retention. + async fn create( + &mut self, + name: &NamespaceName<'_>, + partition_template: Option, + retention_period_ns: Option, + service_protection_limits: Option, + ) -> Result; + + /// Update retention period for a namespace + async fn update_retention_period( + &mut self, + name: &str, + retention_period_ns: Option, + ) -> Result; + + /// List all namespaces. + async fn list(&mut self, deleted: SoftDeletedRows) -> Result>; + + /// Gets the namespace by its ID. + async fn get_by_id( + &mut self, + id: NamespaceId, + deleted: SoftDeletedRows, + ) -> Result>; + + /// Gets the namespace by its unique name. + async fn get_by_name( + &mut self, + name: &str, + deleted: SoftDeletedRows, + ) -> Result>; + + /// Soft-delete a namespace by name + async fn soft_delete(&mut self, name: &str) -> Result<()>; + + /// Update the limit on the number of tables that can exist per namespace. + async fn update_table_limit(&mut self, name: &str, new_max: MaxTables) -> Result; + + /// Update the limit on the number of columns that can exist per table in a given namespace. + async fn update_column_limit( + &mut self, + name: &str, + new_max: MaxColumnsPerTable, + ) -> Result; +} + +/// Functions for working with tables in the catalog +#[async_trait] +pub trait TableRepo: Send + Sync { + /// Creates the table in the catalog. If one in the same namespace with the same name already + /// exists, an error is returned. 
+ async fn create( + &mut self, + name: &str, + partition_template: TablePartitionTemplateOverride, + namespace_id: NamespaceId, + ) -> Result
; + + /// get table by ID + async fn get_by_id(&mut self, table_id: TableId) -> Result>; + + /// get table by namespace ID and name + async fn get_by_namespace_and_name( + &mut self, + namespace_id: NamespaceId, + name: &str, + ) -> Result>; + + /// Lists all tables in the catalog for the given namespace id. + async fn list_by_namespace_id(&mut self, namespace_id: NamespaceId) -> Result>; + + /// List all tables. + async fn list(&mut self) -> Result>; + + /// Obtain a table snapshot + async fn snapshot(&mut self, table_id: TableId) -> Result; +} + +/// Functions for working with columns in the catalog +#[async_trait] +pub trait ColumnRepo: Send + Sync { + /// Creates the column in the catalog or returns the existing column. Will return a + /// `Error::ColumnTypeMismatch` if the existing column type doesn't match the type + /// the caller is attempting to create. + async fn create_or_get( + &mut self, + name: &str, + table_id: TableId, + column_type: ColumnType, + ) -> Result; + + /// Perform a bulk upsert of columns specified by a map of column name to column type. + /// + /// Implementations make no guarantees as to the ordering or atomicity of + /// the batch of column upsert operations - a batch upsert may partially + /// commit, in which case an error MUST be returned by the implementation. + /// + /// Per-namespace limits on the number of columns allowed per table are explicitly NOT checked + /// by this function, hence the name containing `unchecked`. It is expected that the caller + /// will check this first-- and yes, this is racy. + async fn create_or_get_many_unchecked( + &mut self, + table_id: TableId, + columns: HashMap<&str, ColumnType>, + ) -> Result>; + + /// Lists all columns in the passed in namespace id. + async fn list_by_namespace_id(&mut self, namespace_id: NamespaceId) -> Result>; + + /// List all columns for the given table ID. + async fn list_by_table_id(&mut self, table_id: TableId) -> Result>; + + /// List all columns. 
+    // NOTE(review): the enclosing trait of this `list` method starts before this hunk;
+    // the restored `Column` item type is inferred from repo ordering — confirm upstream.
+    async fn list(&mut self) -> Result<Vec<Column>>;
+}
+
+/// Extension trait for [`PartitionRepo`]
+#[async_trait]
+pub trait PartitionRepoExt {
+    /// get the partition by ID, returning `None` if it does not exist
+    async fn get_by_id(self, partition_id: PartitionId) -> Result<Option<Partition>>;
+}
+
+#[async_trait]
+impl PartitionRepoExt for &mut dyn PartitionRepo {
+    async fn get_by_id(self, partition_id: PartitionId) -> Result<Option<Partition>> {
+        // Delegate to the batch lookup; a missing partition simply yields an empty batch.
+        let iter = self.get_by_id_batch(&[partition_id]).await?;
+        Ok(iter.into_iter().next())
+    }
+}
+
+/// Functions for working with IOx partitions in the catalog. These are how IOx splits up
+/// data within a namespace.
+#[async_trait]
+pub trait PartitionRepo: Send + Sync {
+    /// create or get a partition record for the given partition key and table
+    async fn create_or_get(&mut self, key: PartitionKey, table_id: TableId) -> Result<Partition>;
+
+    /// get multiple partitions by ID.
+    ///
+    /// the output order is undefined, non-existing partitions are not part of the output.
+    async fn get_by_id_batch(&mut self, partition_ids: &[PartitionId]) -> Result<Vec<Partition>>;
+
+    /// return the partitions by table id
+    async fn list_by_table_id(&mut self, table_id: TableId) -> Result<Vec<Partition>>;
+
+    /// return all partitions IDs
+    async fn list_ids(&mut self) -> Result<Vec<PartitionId>>;
+
+    /// Update the sort key for the partition, setting it to `new_sort_key_ids` iff
+    /// the current value matches `old_sort_key_ids`.
+    ///
+    /// NOTE: it is expected that ONLY the ingesters update sort keys for
+    /// existing partitions.
+    ///
+    /// # Spurious failure
+    ///
+    /// Implementations are allowed to spuriously return
+    /// [`CasFailure::ValueMismatch`] for performance reasons in the presence of
+    /// concurrent writers.
+    async fn cas_sort_key(
+        &mut self,
+        partition_id: PartitionId,
+        old_sort_key_ids: Option<&SortKeyIds>,
+        new_sort_key_ids: &SortKeyIds,
+    ) -> Result<Partition, CasFailure<SortKeyIds>>;
+
+    /// Record an instance of a partition being selected for compaction but compaction was not
+    /// completed for the specified reason.
+    #[allow(clippy::too_many_arguments)]
+    async fn record_skipped_compaction(
+        &mut self,
+        partition_id: PartitionId,
+        reason: &str,
+        num_files: usize,
+        limit_num_files: usize,
+        limit_num_files_first_in_partition: usize,
+        estimated_bytes: u64,
+        limit_bytes: u64,
+    ) -> Result<()>;
+
+    /// Get the record of partitions being skipped.
+    async fn get_in_skipped_compactions(
+        &mut self,
+        partition_id: &[PartitionId],
+    ) -> Result<Vec<SkippedCompaction>>;
+
+    /// List the records of compacting a partition being skipped. This is mostly useful for testing.
+    async fn list_skipped_compactions(&mut self) -> Result<Vec<SkippedCompaction>>;
+
+    /// Delete the records of skipping a partition being compacted.
+    async fn delete_skipped_compactions(
+        &mut self,
+        partition_id: PartitionId,
+    ) -> Result<Option<SkippedCompaction>>;
+
+    /// Return the N most recently created partitions.
+    async fn most_recent_n(&mut self, n: usize) -> Result<Vec<Partition>>;
+
+    /// Select partitions with a `new_file_at` value greater than the minimum time value and, if specified, less than
+    /// the maximum time value. Both range ends are exclusive; a timestamp exactly equal to either end will _not_ be
+    /// included in the results.
+    async fn partitions_new_file_between(
+        &mut self,
+        minimum_time: Timestamp,
+        maximum_time: Option<Timestamp>,
+    ) -> Result<Vec<PartitionId>>;
+
+    /// Return all partitions that do not have deterministic hash IDs in the catalog. Used in
+    /// the ingester's `OldPartitionBloomFilter` to determine whether a catalog query is necessary.
+    /// Can be removed when all partitions have hash IDs and support for old-style partitions is no
+    /// longer needed.
+    async fn list_old_style(&mut self) -> Result<Vec<Partition>>;
+
+    /// Obtain a partition snapshot
+    async fn snapshot(&mut self, partition_id: PartitionId) -> Result<PartitionSnapshot>;
+}
+
+/// Extension trait for [`ParquetFileRepo`]
+#[async_trait]
+pub trait ParquetFileRepoExt {
+    /// create the parquet file
+    async fn create(self, parquet_file_params: ParquetFileParams) -> Result<ParquetFile>;
+}
+
+#[async_trait]
+impl ParquetFileRepoExt for &mut dyn ParquetFileRepo {
+    /// create the parquet file
+    async fn create(self, params: ParquetFileParams) -> Result<ParquetFile> {
+        // A plain create is the degenerate create/upgrade/delete transaction:
+        // no deletions, no upgrades, one new file at the initial compaction level.
+        let files = self
+            .create_upgrade_delete(
+                params.partition_id,
+                &[],
+                &[],
+                &[params.clone()],
+                CompactionLevel::Initial,
+            )
+            .await?;
+        let id = files.into_iter().next().unwrap();
+        Ok(ParquetFile::from_params(params, id))
+    }
+}
+
+/// Functions for working with parquet file pointers in the catalog
+#[async_trait]
+pub trait ParquetFileRepo: Send + Sync {
+    /// Flag all parquet files for deletion that are older than their namespace's retention period.
+    async fn flag_for_delete_by_retention(&mut self) -> Result<Vec<(PartitionId, ObjectStoreId)>>;
+
+    /// Delete parquet files that were marked to be deleted earlier than the specified time.
+    ///
+    /// Returns the deleted IDs only.
+    ///
+    /// This deletion is limited to a certain (backend-specific) number of files to avoid overlarge
+    /// changes. The caller MAY call this method again if the result was NOT empty.
+    async fn delete_old_ids_only(&mut self, older_than: Timestamp) -> Result<Vec<ObjectStoreId>>;
+
+    /// List parquet files for given partitions that are NOT marked as
+    /// [`to_delete`](ParquetFile::to_delete).
+    ///
+    /// The output order is undefined, non-existing partitions are not part of the output.
+    async fn list_by_partition_not_to_delete_batch(
+        &mut self,
+        partition_ids: Vec<PartitionId>,
+    ) -> Result<Vec<ParquetFile>>;
+
+    /// Return the parquet file with the given object store id
+    // used heavily in tests for verification of catalog state.
+ async fn get_by_object_store_id( + &mut self, + object_store_id: ObjectStoreId, + ) -> Result>; + + /// Test a batch of parquet files exist by object store ids + async fn exists_by_object_store_id_batch( + &mut self, + object_store_ids: Vec, + ) -> Result>; + + /// Commit deletions, upgrades and creations in a single transaction. + /// + /// Returns IDs of created files. + async fn create_upgrade_delete( + &mut self, + partition_id: PartitionId, + delete: &[ObjectStoreId], + upgrade: &[ObjectStoreId], + create: &[ParquetFileParams], + target_level: CompactionLevel, + ) -> Result>; +} diff --git a/iox_catalog/src/interface_tests.rs b/iox_catalog/src/interface_tests.rs new file mode 100644 index 0000000..4635483 --- /dev/null +++ b/iox_catalog/src/interface_tests.rs @@ -0,0 +1,3203 @@ +//! Abstract tests of the catalog interface w/o relying on the actual implementation. +use crate::{ + interface::{ + CasFailure, Catalog, Error, ParquetFileRepoExt, PartitionRepoExt, RepoCollection, + SoftDeletedRows, + }, + test_helpers::{arbitrary_namespace, arbitrary_parquet_file_params, arbitrary_table}, + util::{list_schemas, validate_or_insert_schema}, +}; + +use ::test_helpers::assert_error; +use assert_matches::assert_matches; +use async_trait::async_trait; +use data_types::snapshot::table::TableSnapshot; +use data_types::{ + partition_template::{NamespacePartitionTemplateOverride, TablePartitionTemplateOverride}, + ColumnId, ColumnType, CompactionLevel, MaxColumnsPerTable, MaxTables, Namespace, NamespaceId, + NamespaceName, NamespaceSchema, ObjectStoreId, ParquetFile, ParquetFileId, ParquetFileParams, + PartitionId, SortKeyIds, TableId, Timestamp, +}; +use data_types::{snapshot::partition::PartitionSnapshot, Column, PartitionHashId, PartitionKey}; +use futures::{Future, StreamExt}; +use generated_types::influxdata::iox::partition_template::v1 as proto; +use iox_time::TimeProvider; +use metric::{Attributes, DurationHistogram, Metric}; +use parking_lot::Mutex; +use 
std::{any::Any, fmt::Display}; +use std::{ + collections::{BTreeMap, BTreeSet, HashMap}, + ops::DerefMut, + sync::Arc, + time::Duration, +}; + +pub(crate) async fn test_catalog(clean_state: R) +where + R: Fn() -> F + Send + Sync, + F: Future> + Send, +{ + test_setup(clean_state().await).await; + test_namespace_soft_deletion(clean_state().await).await; + test_partitions_new_file_between(clean_state().await).await; + test_column(clean_state().await).await; + test_partition(clean_state().await).await; + test_parquet_file(clean_state().await).await; + test_parquet_file_delete_broken(clean_state().await).await; + test_update_to_compaction_level_1(clean_state().await).await; + test_list_by_partiton_not_to_delete(clean_state().await).await; + test_list_schemas(clean_state().await).await; + test_list_schemas_soft_deleted_rows(clean_state().await).await; + test_delete_namespace(clean_state().await).await; + + let catalog = clean_state().await; + test_namespace(Arc::clone(&catalog)).await; + assert_metric_hit(&catalog.metrics(), "namespace_create"); + + let catalog = clean_state().await; + test_table(Arc::clone(&catalog)).await; + assert_metric_hit(&catalog.metrics(), "table_create"); + + let catalog = clean_state().await; + test_column(Arc::clone(&catalog)).await; + assert_metric_hit(&catalog.metrics(), "column_create_or_get"); + + let catalog = clean_state().await; + test_partition(Arc::clone(&catalog)).await; + assert_metric_hit(&catalog.metrics(), "partition_create_or_get"); + + let catalog = clean_state().await; + test_parquet_file(Arc::clone(&catalog)).await; + assert_metric_hit(&catalog.metrics(), "parquet_create_upgrade_delete"); + + test_two_repos(clean_state().await).await; + test_partition_create_or_get_idempotent(clean_state().await).await; + test_column_create_or_get_many_unchecked(clean_state).await; +} + +async fn test_setup(catalog: Arc) { + catalog.setup().await.expect("first catalog setup"); + catalog.setup().await.expect("second catalog setup"); +} + 
+/// Exercise namespace CRUD: create, duplicate-create conflict, lookups by id/name,
+/// listing, service-protection-limit updates, retention updates, and custom
+/// partition templates.
+async fn test_namespace(catalog: Arc<dyn Catalog>) {
+    let mut repos = catalog.repositories();
+    let namespace_name = NamespaceName::new("test_namespace").unwrap();
+    let namespace = repos
+        .namespaces()
+        .create(&namespace_name, None, None, None)
+        .await
+        .unwrap();
+    assert!(namespace.id > NamespaceId::new(0));
+    assert_eq!(namespace.name, namespace_name.as_str());
+    assert_eq!(
+        namespace.partition_template,
+        NamespacePartitionTemplateOverride::default()
+    );
+    let lookup_namespace = repos
+        .namespaces()
+        .get_by_name(&namespace_name, SoftDeletedRows::ExcludeDeleted)
+        .await
+        .unwrap()
+        .unwrap();
+    assert_eq!(namespace, lookup_namespace);
+
+    // Assert default values for service protection limits.
+    assert_eq!(namespace.max_tables, MaxTables::default());
+    assert_eq!(
+        namespace.max_columns_per_table,
+        MaxColumnsPerTable::default()
+    );
+
+    let conflict = repos
+        .namespaces()
+        .create(&namespace_name, None, None, None)
+        .await;
+    assert!(matches!(conflict.unwrap_err(), Error::AlreadyExists { .. }));
+
+    let found = repos
+        .namespaces()
+        .get_by_id(namespace.id, SoftDeletedRows::ExcludeDeleted)
+        .await
+        .unwrap()
+        .expect("namespace should be there");
+    assert_eq!(namespace, found);
+
+    let not_found = repos
+        .namespaces()
+        .get_by_id(NamespaceId::new(i64::MAX), SoftDeletedRows::ExcludeDeleted)
+        .await
+        .unwrap();
+    assert!(not_found.is_none());
+
+    let found = repos
+        .namespaces()
+        .get_by_name(&namespace_name, SoftDeletedRows::ExcludeDeleted)
+        .await
+        .unwrap()
+        .expect("namespace should be there");
+    assert_eq!(namespace, found);
+
+    let not_found = repos
+        .namespaces()
+        .get_by_name("does_not_exist", SoftDeletedRows::ExcludeDeleted)
+        .await
+        .unwrap();
+    assert!(not_found.is_none());
+
+    let namespace2 = arbitrary_namespace(&mut *repos, "test_namespace2").await;
+    let mut namespaces = repos
+        .namespaces()
+        .list(SoftDeletedRows::ExcludeDeleted)
+        .await
+        .unwrap();
+    namespaces.sort_by_key(|ns| ns.name.clone());
+    assert_eq!(namespaces, vec![namespace, namespace2]);
+
+    let new_table_limit = MaxTables::try_from(15_000).unwrap();
+    let modified = repos
+        .namespaces()
+        .update_table_limit(namespace_name.as_str(), new_table_limit)
+        .await
+        .expect("namespace should be updateable");
+    assert_eq!(new_table_limit, modified.max_tables);
+
+    let new_column_limit = MaxColumnsPerTable::try_from(1_500).unwrap();
+    let modified = repos
+        .namespaces()
+        .update_column_limit(namespace_name.as_str(), new_column_limit)
+        .await
+        .expect("namespace should be updateable");
+    assert_eq!(new_column_limit, modified.max_columns_per_table);
+
+    const NEW_RETENTION_PERIOD_NS: i64 = 5 * 60 * 60 * 1000 * 1000 * 1000;
+    let modified = repos
+        .namespaces()
+        .update_retention_period(namespace_name.as_str(), Some(NEW_RETENTION_PERIOD_NS))
+        .await
+        .expect("namespace should be updateable");
+    assert_eq!(
+        NEW_RETENTION_PERIOD_NS,
+        modified.retention_period_ns.unwrap()
+    );
+
+    let modified = repos
+        .namespaces()
+        .update_retention_period(namespace_name.as_str(), None)
+        .await
+        .expect("namespace should be updateable");
+    assert!(modified.retention_period_ns.is_none());
+
+    // create namespace with retention period NULL (the default)
+    let namespace3 = arbitrary_namespace(&mut *repos, "test_namespace3").await;
+    assert!(namespace3.retention_period_ns.is_none());
+
+    // create namespace with retention period
+    let namespace4_name = NamespaceName::new("test_namespace4").unwrap();
+    let namespace4 = repos
+        .namespaces()
+        .create(&namespace4_name, None, Some(NEW_RETENTION_PERIOD_NS), None)
+        .await
+        .expect("namespace with 5-hour retention should be created");
+    assert_eq!(
+        NEW_RETENTION_PERIOD_NS,
+        namespace4.retention_period_ns.unwrap()
+    );
+    // reset retention period to NULL to avoid affecting later tests
+    repos
+        .namespaces()
+        .update_retention_period(&namespace4_name, None)
+        .await
+        .expect("namespace should be updateable");
+
+    // create a namespace with a PartitionTemplate other than the default
+    let tag_partition_template =
+        NamespacePartitionTemplateOverride::try_from(proto::PartitionTemplate {
+            parts: vec![proto::TemplatePart {
+                part: Some(proto::template_part::Part::TagValue("tag1".into())),
+            }],
+        })
+        .unwrap();
+    let namespace5_name = NamespaceName::new("test_namespace5").unwrap();
+    let namespace5 = repos
+        .namespaces()
+        .create(
+            &namespace5_name,
+            Some(tag_partition_template.clone()),
+            None,
+            None,
+        )
+        .await
+        .unwrap();
+    assert_eq!(namespace5.partition_template, tag_partition_template);
+    let lookup_namespace5 = repos
+        .namespaces()
+        .get_by_name(&namespace5_name, SoftDeletedRows::ExcludeDeleted)
+        .await
+        .unwrap()
+        .unwrap();
+    assert_eq!(namespace5, lookup_namespace5);
+
+    // remove namespace to avoid it from affecting later tests
+    repos
+        .namespaces()
+        .soft_delete("test_namespace")
+        .await
+        .expect("delete namespace should succeed");
+    repos
+        .namespaces()
+        .soft_delete("test_namespace2")
+        .await
+        .expect("delete namespace should succeed");
+    repos
+        .namespaces()
+        .soft_delete("test_namespace3")
+        .await
+        .expect("delete namespace should succeed");
+    repos
+        .namespaces()
+        .soft_delete("test_namespace4")
+        .await
+        .expect("delete namespace should succeed");
+}
+
+/// Construct a set of two namespaces:
+///
+/// * deleted-ns: marked as soft-deleted
+/// * active-ns: not marked as deleted
+///
+/// And assert the expected "soft delete" semantics / correctly filter out
+/// the expected rows for all three states of [`SoftDeletedRows`].
+async fn test_namespace_soft_deletion(catalog: Arc<dyn Catalog>) {
+    let mut repos = catalog.repositories();
+
+    let deleted_ns = arbitrary_namespace(&mut *repos, "deleted-ns").await;
+    let active_ns = arbitrary_namespace(&mut *repos, "active-ns").await;
+
+    // Mark "deleted-ns" as soft-deleted.
+    repos.namespaces().soft_delete("deleted-ns").await.unwrap();
+
+    // Which should be idempotent (ignoring the timestamp change - when
+    // changing this to "soft delete" it was idempotent, so I am preserving
+    // that).
+    repos.namespaces().soft_delete("deleted-ns").await.unwrap();
+
+    // Listing should respect soft deletion.
+    let got = repos
+        .namespaces()
+        .list(SoftDeletedRows::AllRows)
+        .await
+        .unwrap()
+        .into_iter()
+        .map(|v| v.name);
+    assert_string_set_eq(got, ["deleted-ns", "active-ns"]);
+
+    let got = repos
+        .namespaces()
+        .list(SoftDeletedRows::OnlyDeleted)
+        .await
+        .unwrap()
+        .into_iter()
+        .map(|v| v.name);
+    assert_string_set_eq(got, ["deleted-ns"]);
+
+    let got = repos
+        .namespaces()
+        .list(SoftDeletedRows::ExcludeDeleted)
+        .await
+        .unwrap()
+        .into_iter()
+        .map(|v| v.name);
+    assert_string_set_eq(got, ["active-ns"]);
+
+    // As should get by ID
+    let got = repos
+        .namespaces()
+        .get_by_id(deleted_ns.id, SoftDeletedRows::AllRows)
+        .await
+        .unwrap()
+        .into_iter()
+        .map(|v| v.name);
+    assert_string_set_eq(got, ["deleted-ns"]);
+    let got = repos
+        .namespaces()
+        .get_by_id(deleted_ns.id, SoftDeletedRows::OnlyDeleted)
+        .await
+        .unwrap()
+        .into_iter()
+        .map(|v| {
+            assert!(v.deleted_at.is_some());
+            v.name
+        });
+    assert_string_set_eq(got, ["deleted-ns"]);
+    let got = repos
+        .namespaces()
+        .get_by_id(deleted_ns.id, SoftDeletedRows::ExcludeDeleted)
+        .await
+        .unwrap();
+    assert!(got.is_none());
+    let got = repos
+        .namespaces()
+        .get_by_id(active_ns.id, SoftDeletedRows::AllRows)
+        .await
+        .unwrap()
+        .into_iter()
+        .map(|v| v.name);
+    assert_string_set_eq(got, ["active-ns"]);
+    let got = repos
+        .namespaces()
+        .get_by_id(active_ns.id, SoftDeletedRows::OnlyDeleted)
+        .await
+        .unwrap();
+    assert!(got.is_none());
+    let got = repos
+        .namespaces()
+        .get_by_id(active_ns.id, SoftDeletedRows::ExcludeDeleted)
+        .await
+        .unwrap()
+        .into_iter()
+        .map(|v| v.name);
+    assert_string_set_eq(got, ["active-ns"]);
+
+    // And get by name
+    let got = repos
+        .namespaces()
+        .get_by_name(&deleted_ns.name, SoftDeletedRows::AllRows)
+        .await
+        .unwrap()
+        .into_iter()
+        .map(|v| v.name);
+    assert_string_set_eq(got, ["deleted-ns"]);
+    let got = repos
+        .namespaces()
+        .get_by_name(&deleted_ns.name, SoftDeletedRows::OnlyDeleted)
+        .await
+        .unwrap()
+        .into_iter()
+        .map(|v| {
+            assert!(v.deleted_at.is_some());
+            v.name
+        });
+    assert_string_set_eq(got, ["deleted-ns"]);
+    let got = repos
+        .namespaces()
+        .get_by_name(&deleted_ns.name, SoftDeletedRows::ExcludeDeleted)
+        .await
+        .unwrap();
+    assert!(got.is_none());
+    let got = repos
+        .namespaces()
+        .get_by_name(&active_ns.name, SoftDeletedRows::AllRows)
+        .await
+        .unwrap()
+        .into_iter()
+        .map(|v| v.name);
+    assert_string_set_eq(got, ["active-ns"]);
+    let got = repos
+        .namespaces()
+        .get_by_name(&active_ns.name, SoftDeletedRows::OnlyDeleted)
+        .await
+        .unwrap();
+    assert!(got.is_none());
+    let got = repos
+        .namespaces()
+        .get_by_name(&active_ns.name, SoftDeletedRows::ExcludeDeleted)
+        .await
+        .unwrap()
+        .into_iter()
+        .map(|v| v.name);
+    assert_string_set_eq(got, ["active-ns"]);
+}
+
+// Assert the set of strings "a" is equal to the set "b", tolerating
+// duplicates.
+#[track_caller]
+fn assert_string_set_eq<T, U>(a: impl IntoIterator<Item = T>, b: impl IntoIterator<Item = U>)
+where
+    T: Into<String>,
+    U: Into<String>,
+{
+    let mut a = a.into_iter().map(Into::into).collect::<Vec<String>>();
+    a.sort_unstable();
+    let mut b = b.into_iter().map(Into::into).collect::<Vec<String>>();
+    b.sort_unstable();
+    assert_eq!(a, b);
+}
+
+/// Exercise table CRUD: creation, duplicate conflict, per-namespace lookups,
+/// per-namespace table limits, and partition templates (table-level and
+/// inherited from the namespace).
+async fn test_table(catalog: Arc<dyn Catalog>) {
+    let mut repos = catalog.repositories();
+    let namespace = arbitrary_namespace(&mut *repos, "namespace_table_test").await;
+
+    // test we can create a table
+    let t = arbitrary_table(&mut *repos, "test_table", &namespace).await;
+    assert!(t.id > TableId::new(0));
+    assert_eq!(
+        t.partition_template,
+        TablePartitionTemplateOverride::default()
+    );
+
+    // The default template doesn't use any tag values, so no columns need to be created.
+    let table_columns = repos.columns().list_by_table_id(t.id).await.unwrap();
+    assert!(table_columns.is_empty());
+
+    // test we get an error if we try to create it again
+    let err = repos
+        .tables()
+        .create(
+            "test_table",
+            TablePartitionTemplateOverride::try_new(None, &namespace.partition_template).unwrap(),
+            namespace.id,
+        )
+        .await;
+    assert_error!(
+        err,
+        Error::AlreadyExists { ref descr }
+            if descr == &format!("table 'test_table' in namespace {}", namespace.id)
+    );
+
+    // get by id
+    assert_eq!(t, repos.tables().get_by_id(t.id).await.unwrap().unwrap());
+    assert!(repos
+        .tables()
+        .get_by_id(TableId::new(i64::MAX))
+        .await
+        .unwrap()
+        .is_none());
+
+    let tables = repos
+        .tables()
+        .list_by_namespace_id(namespace.id)
+        .await
+        .unwrap();
+    assert_eq!(vec![t.clone()], tables);
+
+    // test we can create a table of the same name in a different namespace
+    let namespace2 = arbitrary_namespace(&mut *repos, "two").await;
+    assert_ne!(namespace, namespace2);
+    let test_table = arbitrary_table(&mut *repos, "test_table", &namespace2).await;
+    assert_ne!(t.id, test_table.id);
+    assert_eq!(test_table.namespace_id, namespace2.id);
+
+    // test get by namespace and name
+    let foo_table = arbitrary_table(&mut *repos, "foo", &namespace2).await;
+    assert_eq!(
+        repos
+            .tables()
+            .get_by_namespace_and_name(NamespaceId::new(i64::MAX), "test_table")
+            .await
+            .unwrap(),
+        None
+    );
+    assert_eq!(
+        repos
+            .tables()
+            .get_by_namespace_and_name(namespace.id, "not_existing")
+            .await
+            .unwrap(),
+        None
+    );
+    assert_eq!(
+        repos
+            .tables()
+            .get_by_namespace_and_name(namespace.id, "test_table")
+            .await
+            .unwrap(),
+        Some(t.clone())
+    );
+    assert_eq!(
+        repos
+            .tables()
+            .get_by_namespace_and_name(namespace2.id, "test_table")
+            .await
+            .unwrap()
+            .as_ref(),
+        Some(&test_table)
+    );
+    assert_eq!(
+        repos
+            .tables()
+            .get_by_namespace_and_name(namespace2.id, "foo")
+            .await
+            .unwrap()
+            .as_ref(),
+        Some(&foo_table)
+    );
+
+    // All tables should be returned by list(), regardless of namespace
+    let mut list = repos.tables().list().await.unwrap();
+    list.sort_by_key(|t| t.id);
+    let mut expected = [t, test_table, foo_table];
+    expected.sort_by_key(|t| t.id);
+    assert_eq!(&list, &expected);
+
+    // test per-namespace table limits
+    let latest = repos
+        .namespaces()
+        .update_table_limit("namespace_table_test", MaxTables::try_from(1).unwrap())
+        .await
+        .expect("namespace should be updateable");
+    let err = repos
+        .tables()
+        .create(
+            "definitely_unique",
+            TablePartitionTemplateOverride::try_new(None, &latest.partition_template).unwrap(),
+            latest.id,
+        )
+        .await
+        .expect_err("should error with table create limit error");
+    assert!(matches!(err, Error::LimitExceeded { .. }));
+
+    // Create a table with a partition template other than the default
+    let custom_table_template = TablePartitionTemplateOverride::try_new(
+        Some(proto::PartitionTemplate {
+            parts: vec![
+                proto::TemplatePart {
+                    part: Some(proto::template_part::Part::TagValue("tag1".into())),
+                },
+                proto::TemplatePart {
+                    part: Some(proto::template_part::Part::TimeFormat("year-%Y".into())),
+                },
+                proto::TemplatePart {
+                    part: Some(proto::template_part::Part::TagValue("tag2".into())),
+                },
+            ],
+        }),
+        &namespace2.partition_template,
+    )
+    .unwrap();
+    let templated = repos
+        .tables()
+        .create(
+            "use_a_template",
+            custom_table_template.clone(),
+            namespace2.id,
+        )
+        .await
+        .unwrap();
+    assert_eq!(templated.partition_template, custom_table_template);
+
+    // Tag columns should be created for tags used in the template
+    let table_columns = repos
+        .columns()
+        .list_by_table_id(templated.id)
+        .await
+        .unwrap();
+    assert_eq!(table_columns.len(), 2);
+    assert!(table_columns.iter().all(|c| c.is_tag()));
+    let mut column_names: Vec<_> = table_columns.iter().map(|c| &c.name).collect();
+    column_names.sort();
+    assert_eq!(column_names, &["tag1", "tag2"]);
+
+    let lookup_templated = repos
+        .tables()
+        .get_by_namespace_and_name(namespace2.id, "use_a_template")
+        .await
+        .unwrap()
+        .unwrap();
+    assert_eq!(templated, lookup_templated);
+
+    // Create a namespace with a partition template other than the default
+    let custom_namespace_template =
+        NamespacePartitionTemplateOverride::try_from(proto::PartitionTemplate {
+            parts: vec![
+                proto::TemplatePart {
+                    part: Some(proto::template_part::Part::TagValue("zzz".into())),
+                },
+                proto::TemplatePart {
+                    part: Some(proto::template_part::Part::TagValue("aaa".into())),
+                },
+                proto::TemplatePart {
+                    part: Some(proto::template_part::Part::TimeFormat("year-%Y".into())),
+                },
+            ],
+        })
+        .unwrap();
+    let custom_namespace_name = NamespaceName::new("custom_namespace").unwrap();
+    let custom_namespace = repos
+        .namespaces()
+        .create(
+            &custom_namespace_name,
+            Some(custom_namespace_template.clone()),
+            None,
+            None,
+        )
+        .await
+        .unwrap();
+    // Create a table without specifying the partition template
+    let custom_table_template =
+        TablePartitionTemplateOverride::try_new(None, &custom_namespace.partition_template)
+            .unwrap();
+    let table_templated_by_namespace = repos
+        .tables()
+        .create(
+            "use_namespace_template",
+            custom_table_template,
+            custom_namespace.id,
+        )
+        .await
+        .unwrap();
+    assert_eq!(
+        table_templated_by_namespace.partition_template,
+        TablePartitionTemplateOverride::try_new(None, &custom_namespace_template).unwrap()
+    );
+
+    // Tag columns should be created for tags used in the template
+    let table_columns = repos
+        .columns()
+        .list_by_table_id(table_templated_by_namespace.id)
+        .await
+        .unwrap();
+    assert_eq!(table_columns.len(), 2);
+    assert!(table_columns.iter().all(|c| c.is_tag()));
+    let mut column_names: Vec<_> = table_columns.iter().map(|c| &c.name).collect();
+    column_names.sort();
+    assert_eq!(column_names, &["aaa", "zzz"]);
+
+    repos
+        .namespaces()
+        .soft_delete("namespace_table_test")
+        .await
+        .expect("delete namespace should succeed");
+    repos
+        .namespaces()
+        .soft_delete("two")
+        .await
+        .expect("delete namespace should succeed");
+}
+
+/// Exercise column CRUD: create-or-get idempotency, type-conflict errors,
+/// per-table and per-namespace listing, table snapshots, and column limits.
+async fn test_column(catalog: Arc<dyn Catalog>) {
+    let mut repos = catalog.repositories();
+    let namespace = arbitrary_namespace(&mut *repos, "namespace_column_test").await;
+    let table = arbitrary_table(&mut *repos, "test_table", &namespace).await;
+    assert_eq!(table.namespace_id, namespace.id);
+
+    // test we can create or get a column
+    let c = repos
+        .columns()
+        .create_or_get("column_test", table.id, ColumnType::Tag)
+        .await
+        .unwrap();
+
+    let ts1 = repos.tables().snapshot(table.id).await.unwrap();
+    validate_table_snapshot(repos.as_mut(), &ts1).await;
+
+    let cc = repos
+        .columns()
+        .create_or_get("column_test", table.id, ColumnType::Tag)
+        .await
+        .unwrap();
+    assert!(c.id > ColumnId::new(0));
+    assert_eq!(c, cc);
+
+    let ts2 = repos.tables().snapshot(table.id).await.unwrap();
+    validate_table_snapshot(repos.as_mut(), &ts2).await;
+
+    assert_gt(ts2.generation(), ts1.generation());
+
+    // test that attempting to create an already defined column of a different type returns
+    // error
+    let err = repos
+        .columns()
+        .create_or_get("column_test", table.id, ColumnType::U64)
+        .await
+        .expect_err("should error with wrong column type");
+    assert!(matches!(err, Error::AlreadyExists { .. }));
+
+    // test that we can create a column of the same name under a different table
+    let table2 = arbitrary_table(&mut *repos, "test_table_2", &namespace).await;
+    let ccc = repos
+        .columns()
+        .create_or_get("column_test", table2.id, ColumnType::U64)
+        .await
+        .unwrap();
+    assert_ne!(c, ccc);
+
+    let columns = repos
+        .columns()
+        .list_by_namespace_id(namespace.id)
+        .await
+        .unwrap();
+
+    let ts3 = repos.tables().snapshot(table2.id).await.unwrap();
+    validate_table_snapshot(repos.as_mut(), &ts3).await;
+
+    let mut want = vec![c.clone(), ccc];
+    assert_eq!(want, columns);
+
+    let columns = repos.columns().list_by_table_id(table.id).await.unwrap();
+
+    let want2 = vec![c];
+    assert_eq!(want2, columns);
+
+    // Add another tag column into table2
+    let c3 = repos
+        .columns()
+        .create_or_get("b", table2.id, ColumnType::Tag)
+        .await
+        .unwrap();
+
+    let ts4 = repos.tables().snapshot(table2.id).await.unwrap();
+    validate_table_snapshot(repos.as_mut(), &ts4).await;
+
+    assert_gt(ts4.generation(), ts3.generation());
+
+    // Listing columns should return all columns in the catalog
+    let list = repos.columns().list().await.unwrap();
+    want.extend([c3]);
+    assert_eq!(list, want);
+
+    // test create_or_get_many_unchecked, below column limit
+    let mut columns = HashMap::new();
+    columns.insert("column_test", ColumnType::Tag);
+    columns.insert("new_column", ColumnType::Tag);
+    let table1_columns = repos
+        .columns()
+        .create_or_get_many_unchecked(table.id, columns)
+        .await
+        .unwrap();
+    let mut table1_column_names: Vec<_> = table1_columns.iter().map(|c| &c.name).collect();
+    table1_column_names.sort();
+    assert_eq!(table1_column_names, vec!["column_test", "new_column"]);
+
+    // test per-namespace column limits
+    repos
+        .namespaces()
+        .update_column_limit(
+            "namespace_column_test",
+            MaxColumnsPerTable::try_from(1).unwrap(),
+        )
+        .await
+        .expect("namespace should be updateable");
+    let err = repos
+        .columns()
+        .create_or_get("definitely unique", table.id,
ColumnType::Tag)
+        .await
+        .expect_err("should error with table create limit error");
+    assert!(matches!(err, Error::LimitExceeded { .. }));
+
+    // test per-namespace column limits are NOT enforced with create_or_get_many_unchecked
+    let table3 = arbitrary_table(&mut *repos, "test_table_3", &namespace).await;
+    let mut columns = HashMap::new();
+    columns.insert("apples", ColumnType::Tag);
+    columns.insert("oranges", ColumnType::Tag);
+    let table3_columns = repos
+        .columns()
+        .create_or_get_many_unchecked(table3.id, columns)
+        .await
+        .unwrap();
+    let mut table3_column_names: Vec<_> = table3_columns.iter().map(|c| &c.name).collect();
+    table3_column_names.sort();
+    assert_eq!(table3_column_names, vec!["apples", "oranges"]);
+
+    repos
+        .namespaces()
+        .soft_delete("namespace_column_test")
+        .await
+        .expect("delete namespace should succeed");
+}
+
+/// Exercise partition CRUD: create-or-get, batch lookup, listing, sort-key CAS
+/// (success and mismatch paths), snapshots, and skipped-compaction records.
+async fn test_partition(catalog: Arc<dyn Catalog>) {
+    let mut repos = catalog.repositories();
+    let namespace = arbitrary_namespace(&mut *repos, "namespace_partition_test").await;
+    let table = arbitrary_table(&mut *repos, "test_table", &namespace).await;
+
+    let mut created = BTreeMap::new();
+    // partition to use
+    let partition = repos
+        .partitions()
+        .create_or_get("foo".into(), table.id)
+        .await
+        .expect("failed to create partition");
+    // Test: sort_key_ids from create_or_get
+    assert!(partition.sort_key_ids().is_none());
+    created.insert(partition.id, partition.clone());
+    // partition to use
+    let partition_bar = repos
+        .partitions()
+        .create_or_get("bar".into(), table.id)
+        .await
+        .expect("failed to create partition");
+    created.insert(partition_bar.id, partition_bar);
+    // partition to be skipped later
+    let to_skip_partition = repos
+        .partitions()
+        .create_or_get("asdf".into(), table.id)
+        .await
+        .unwrap();
+    created.insert(to_skip_partition.id, to_skip_partition.clone());
+    // partition to be skipped later
+    let to_skip_partition_too = repos
+        .partitions()
+        .create_or_get("asdf too".into(), table.id)
+        .await
+        .unwrap();
+    created.insert(to_skip_partition_too.id, to_skip_partition_too.clone());
+
+    // partitions can be retrieved easily
+    let mut created_sorted = created.values().cloned().collect::<Vec<_>>();
+    created_sorted.sort_by_key(|p| p.id);
+    assert_eq!(
+        to_skip_partition,
+        repos
+            .partitions()
+            .get_by_id_batch(&[to_skip_partition.id])
+            .await
+            .unwrap()
+            .into_iter()
+            .next()
+            .unwrap()
+    );
+    let non_existing_partition_id = PartitionId::new(i64::MAX);
+    assert!(repos
+        .partitions()
+        .get_by_id_batch(&[non_existing_partition_id])
+        .await
+        .unwrap()
+        .is_empty());
+    let mut batch = repos
+        .partitions()
+        .get_by_id_batch(
+            &created
+                .keys()
+                .cloned()
+                // non-existing entries are ignored
+                .chain([non_existing_partition_id])
+                // duplicates are ignored
+                .chain(created.keys().cloned())
+                .collect::<Vec<_>>(),
+        )
+        .await
+        .unwrap();
+    batch.sort_by_key(|p| p.id);
+    assert_eq!(created_sorted, batch);
+    // Test: sort_key_ids from get_by_id_batch
+    assert!(batch.iter().all(|p| p.sort_key_ids().is_none()));
+
+    assert_eq!(created_sorted, batch);
+
+    let s1 = repos.tables().snapshot(table.id).await.unwrap();
+    validate_table_snapshot(repos.as_mut(), &s1).await;
+
+    let listed = repos
+        .partitions()
+        .list_by_table_id(table.id)
+        .await
+        .expect("failed to list partitions")
+        .into_iter()
+        .map(|v| (v.id, v))
+        .collect::<BTreeMap<_, _>>();
+    // Test: sort_key_ids from list_by_table_id
+    assert!(listed.values().all(|p| p.sort_key_ids().is_none()));
+
+    assert_eq!(created, listed);
+
+    let listed = repos
+        .partitions()
+        .list_ids()
+        .await
+        .expect("failed to list partitions")
+        .into_iter()
+        .collect::<BTreeSet<_>>();
+
+    assert_eq!(created.keys().copied().collect::<BTreeSet<_>>(), listed);
+
+    // The code no longer supports creating old-style partitions, so this list is always empty
+    // in these tests. See each catalog implementation for tests that insert old-style
+    // partitions directly and verify they're returned.
+    let old_style = repos.partitions().list_old_style().await.unwrap();
+    assert!(
+        old_style.is_empty(),
+        "Expected no old-style partitions, got {old_style:?}"
+    );
+
+    // sort key should be unset on creation
+    assert!(to_skip_partition.sort_key_ids().is_none());
+
+    let s1 = repos
+        .partitions()
+        .snapshot(to_skip_partition.id)
+        .await
+        .unwrap();
+    validate_partition_snapshot(repos.as_mut(), &s1).await;
+
+    // test that updates sort key from None to Some
+    let updated_partition = repos
+        .partitions()
+        .cas_sort_key(to_skip_partition.id, None, &SortKeyIds::from([2, 1, 3]))
+        .await
+        .unwrap();
+
+    // verify sort key is updated correctly
+    assert_eq!(
+        updated_partition.sort_key_ids().unwrap(),
+        &SortKeyIds::from([2, 1, 3])
+    );
+
+    let s2 = repos
+        .partitions()
+        .snapshot(to_skip_partition.id)
+        .await
+        .unwrap();
+    assert_gt(s2.generation(), s1.generation());
+    validate_partition_snapshot(repos.as_mut(), &s2).await;
+
+    // test that provides value of old_sort_key_ids but it do not match the existing one
+    // --> the new sort key will not be updated
+    let err = repos
+        .partitions()
+        .cas_sort_key(
+            to_skip_partition.id,
+            Some(&SortKeyIds::from([1])),
+            &SortKeyIds::from([1, 2, 3, 4]),
+        )
+        .await
+        .expect_err("CAS with incorrect value should fail");
+    // verify the sort key is not updated
+    assert_matches!(err, CasFailure::ValueMismatch(old_sort_key_ids) => {
+        assert_eq!(old_sort_key_ids, SortKeyIds::from([2, 1, 3]));
+    });
+
+    // test that provides same length but not-matched old_sort_key_ids
+    // --> the new sort key will not be updated
+    let err = repos
+        .partitions()
+        .cas_sort_key(
+            to_skip_partition.id,
+            Some(&SortKeyIds::from([1, 5, 10])),
+            &SortKeyIds::from([1, 2, 3, 4]),
+        )
+        .await
+        .expect_err("CAS with incorrect value should fail");
+    // verify the sort key is not updated
+    assert_matches!(err, CasFailure::ValueMismatch(old_sort_key_ids) => {
+        assert_eq!(old_sort_key_ids, SortKeyIds::from([2, 1, 3]));
+    });
+
+    // test that provide None sort_key_ids that do not match with existing values that are not None
+    // --> the new sort key will not be updated
+    let err = repos
+        .partitions()
+        .cas_sort_key(to_skip_partition.id, None, &SortKeyIds::from([1, 2, 3, 4]))
+        .await
+        .expect_err("CAS with incorrect value should fail");
+    assert_matches!(err, CasFailure::ValueMismatch(old_sort_key_ids) => {
+        assert_eq!(old_sort_key_ids, SortKeyIds::from([2, 1, 3]));
+    });
+
+    // test getting partition from partition id and verify values of sort_key and sort_key_ids
+    let updated_other_partition = repos
+        .partitions()
+        .get_by_id_batch(&[to_skip_partition.id])
+        .await
+        .unwrap()
+        .into_iter()
+        .next()
+        .unwrap();
+    // still has the old sort key
+    assert_eq!(
+        updated_other_partition.sort_key_ids().unwrap(),
+        &SortKeyIds::from([2, 1, 3])
+    );
+
+    // test that updates sort_key_ids from Some matching value to Some other value
+    let updated_partition = repos
+        .partitions()
+        .cas_sort_key(
+            to_skip_partition.id,
+            Some(&SortKeyIds::from([2, 1, 3])),
+            &SortKeyIds::from([2, 1, 4, 3]),
+        )
+        .await
+        .unwrap();
+    // verify the new values are updated
+    assert_eq!(
+        updated_partition.sort_key_ids().unwrap(),
+        &SortKeyIds::from([2, 1, 4, 3])
+    );
+
+    // test getting the new sort key from partition id
+    let updated_partition = repos
+        .partitions()
+        .get_by_id_batch(&[to_skip_partition.id])
+        .await
+        .unwrap()
+        .into_iter()
+        .next()
+        .unwrap();
+    assert_eq!(
+        updated_partition.sort_key_ids().unwrap(),
+        &SortKeyIds::from([2, 1, 4, 3])
+    );
+
+    // use to_skip_partition_too to update sort key from empty old values
+    // first make sure the old sort key is unset
+    assert!(to_skip_partition_too.sort_key_ids().is_none());
+
+    // test that provides empty old_sort_key_ids
+    // --> the new sort key will be updated
+    let updated_to_skip_partition_too = repos
+        .partitions()
+        .cas_sort_key(to_skip_partition_too.id, None, &SortKeyIds::from([3, 4]))
+        .await
+        .unwrap();
+    // verify the new values are updated
+    assert_eq!(
+        updated_to_skip_partition_too.sort_key_ids().unwrap(),
+        &SortKeyIds::from([3, 4])
+    );
+
+    let s3 = repos
+        .partitions()
+        .snapshot(to_skip_partition.id)
+        .await
+        .unwrap();
+    assert_gt(s3.generation(), s2.generation());
+    validate_partition_snapshot(repos.as_mut(), &s3).await;
+
+    // The compactor can log why compaction was skipped
+    let skipped_compactions = repos.partitions().list_skipped_compactions().await.unwrap();
+    assert!(
+        skipped_compactions.is_empty(),
+        "Expected no skipped compactions, got: {skipped_compactions:?}"
+    );
+    repos
+        .partitions()
+        .record_skipped_compaction(to_skip_partition.id, "I am le tired", 1, 2, 4, 10, 20)
+        .await
+        .unwrap();
+    let skipped_compactions = repos.partitions().list_skipped_compactions().await.unwrap();
+    assert_eq!(skipped_compactions.len(), 1);
+    assert_eq!(skipped_compactions[0].partition_id, to_skip_partition.id);
+    assert_eq!(skipped_compactions[0].reason, "I am le tired");
+    assert_eq!(skipped_compactions[0].num_files, 1);
+    assert_eq!(skipped_compactions[0].limit_num_files, 2);
+    assert_eq!(skipped_compactions[0].estimated_bytes, 10);
+    assert_eq!(skipped_compactions[0].limit_bytes, 20);
+    //
+    let skipped_partition_records = repos
+        .partitions()
+        .get_in_skipped_compactions(&[
+            to_skip_partition.id,
+            PartitionId::new(i64::MAX),
+            to_skip_partition.id,
+        ])
+        .await
+        .unwrap();
+    assert_eq!(
+        skipped_partition_records[0].partition_id,
+        to_skip_partition.id
+    );
+    assert_eq!(skipped_partition_records[0].reason, "I am le tired");
+
+    let s4 = repos
+        .partitions()
+        .snapshot(to_skip_partition.id)
+        .await
+        .unwrap();
+    assert_gt(s4.generation(), s3.generation());
+    validate_partition_snapshot(repos.as_mut(), &s4).await;
+
+    // Only save the last reason that any particular partition was skipped (really if the
+    // partition appears in the skipped compactions, it shouldn't become a compaction candidate
+    // again, but race conditions and all that)
+    repos
+        .partitions()
+        .record_skipped_compaction(to_skip_partition.id, "I'm on fire", 11, 12, 24, 110, 120)
+        .await
+        .unwrap();
+    let skipped_compactions = repos.partitions().list_skipped_compactions().await.unwrap();
+    assert_eq!(skipped_compactions.len(), 1);
+    assert_eq!(skipped_compactions[0].partition_id, to_skip_partition.id);
+    assert_eq!(skipped_compactions[0].reason, "I'm on fire");
+    assert_eq!(skipped_compactions[0].num_files, 11);
+    assert_eq!(skipped_compactions[0].limit_num_files, 12);
+    assert_eq!(skipped_compactions[0].estimated_bytes, 110);
+    assert_eq!(skipped_compactions[0].limit_bytes, 120);
+    //
+    let skipped_partition_records = repos
+        .partitions()
+        .get_in_skipped_compactions(&[to_skip_partition.id])
+        .await
+        .unwrap();
+    assert_eq!(
+        skipped_partition_records[0].partition_id,
+        to_skip_partition.id
+    );
+    assert_eq!(skipped_partition_records[0].reason, "I'm on fire");
+
+    // Can receive multiple skipped compactions for different partitions
+    repos
+        .partitions()
+        .record_skipped_compaction(
+            to_skip_partition_too.id,
+            "I am le tired too",
+            1,
+            2,
+            4,
+            10,
+            20,
+        )
+        .await
+        .unwrap();
+    let skipped_compactions = repos.partitions().list_skipped_compactions().await.unwrap();
+    assert_eq!(skipped_compactions.len(), 2);
+    assert_eq!(skipped_compactions[0].partition_id, to_skip_partition.id);
+    assert_eq!(
+        skipped_compactions[1].partition_id,
+        to_skip_partition_too.id
+    );
+    // confirm can fetch subset of skipped compactions (a.k.a.
have two, only fetch 1) + let skipped_partition_records = repos + .partitions() + .get_in_skipped_compactions(&[to_skip_partition.id]) + .await + .unwrap(); + assert_eq!(skipped_partition_records.len(), 1); + assert_eq!(skipped_compactions[0].partition_id, to_skip_partition.id); + let skipped_partition_records = repos + .partitions() + .get_in_skipped_compactions(&[to_skip_partition_too.id]) + .await + .unwrap(); + assert_eq!(skipped_partition_records.len(), 1); + assert_eq!( + skipped_partition_records[0].partition_id, + to_skip_partition_too.id + ); + // confirm can fetch both skipped compactions, and not the unskipped one + // also confirm will not error on non-existing partition + let non_existing_partition_id = PartitionId::new(9999); + let skipped_partition_records = repos + .partitions() + .get_in_skipped_compactions(&[ + partition.id, + to_skip_partition.id, + to_skip_partition_too.id, + non_existing_partition_id, + ]) + .await + .unwrap(); + assert_eq!(skipped_partition_records.len(), 2); + assert_eq!( + skipped_partition_records[0].partition_id, + to_skip_partition.id + ); + assert_eq!( + skipped_partition_records[1].partition_id, + to_skip_partition_too.id + ); + + // Delete the skipped compactions + let deleted_skipped_compaction = repos + .partitions() + .delete_skipped_compactions(to_skip_partition.id) + .await + .unwrap() + .expect("The skipped compaction should have been returned"); + assert_eq!( + deleted_skipped_compaction.partition_id, + to_skip_partition.id + ); + assert_eq!(deleted_skipped_compaction.reason, "I'm on fire"); + assert_eq!(deleted_skipped_compaction.num_files, 11); + assert_eq!(deleted_skipped_compaction.limit_num_files, 12); + assert_eq!(deleted_skipped_compaction.estimated_bytes, 110); + assert_eq!(deleted_skipped_compaction.limit_bytes, 120); + // + let deleted_skipped_compaction = repos + .partitions() + .delete_skipped_compactions(to_skip_partition_too.id) + .await + .unwrap() + .expect("The skipped compaction should have 
been returned"); + assert_eq!( + deleted_skipped_compaction.partition_id, + to_skip_partition_too.id + ); + assert_eq!(deleted_skipped_compaction.reason, "I am le tired too"); + // + let skipped_partition_records = repos + .partitions() + .get_in_skipped_compactions(&[to_skip_partition.id]) + .await + .unwrap(); + assert!(skipped_partition_records.is_empty()); + + let not_deleted_skipped_compaction = repos + .partitions() + .delete_skipped_compactions(to_skip_partition.id) + .await + .unwrap(); + + assert!( + not_deleted_skipped_compaction.is_none(), + "There should be no skipped compation", + ); + + let skipped_compactions = repos.partitions().list_skipped_compactions().await.unwrap(); + assert!( + skipped_compactions.is_empty(), + "Expected no skipped compactions, got: {skipped_compactions:?}" + ); + + let recent = repos + .partitions() + .most_recent_n(10) + .await + .expect("should list most recent"); + assert_eq!(recent.len(), 4); + + // Test: sort_key_ids from most_recent_n + // Only the first two partitions (represent to_skip_partition_too and to_skip_partition) have vallues, the others are empty + assert_eq!( + recent[0].sort_key_ids().unwrap(), + &SortKeyIds::from(vec![3, 4]) + ); + assert_eq!( + recent[1].sort_key_ids().unwrap(), + &SortKeyIds::from(vec![2, 1, 4, 3]) + ); + assert!(recent[2].sort_key_ids().is_none()); + assert!(recent[3].sort_key_ids().is_none()); + + let recent = repos + .partitions() + .most_recent_n(4) + .await + .expect("should list most recent"); + assert_eq!(recent.len(), 4); // no off by one error + + let recent = repos + .partitions() + .most_recent_n(2) + .await + .expect("should list most recent"); + assert_eq!(recent.len(), 2); + + repos + .namespaces() + .soft_delete("namespace_partition_test") + .await + .expect("delete namespace should succeed"); +} + +async fn validate_partition_snapshot(repos: &mut dyn RepoCollection, snapshot: &PartitionSnapshot) { + // compare files + let mut expected = repos + .parquet_files() + 
.list_by_partition_not_to_delete_batch(vec![snapshot.partition_id()]) + .await + .unwrap(); + expected.sort_unstable_by_key(|x| x.id); + let mut actual = snapshot.files().collect::, _>>().unwrap(); + actual.sort_unstable_by_key(|x| x.id); + assert_eq!(expected, actual); + + // compare skipped partition + let expected = repos + .partitions() + .get_in_skipped_compactions(&[snapshot.partition_id()]) + .await + .unwrap() + .into_iter() + .next(); + let actual = snapshot.skipped_compaction(); + assert_eq!(actual, expected); + + // compare partition itself + let actual = snapshot.partition().unwrap(); + let expected = repos + .partitions() + .get_by_id(snapshot.partition_id()) + .await + .unwrap() + .unwrap(); + assert_eq!(actual, expected); +} + +async fn validate_table_snapshot(repos: &mut dyn RepoCollection, snapshot: &TableSnapshot) { + let table = snapshot.table().unwrap(); + + let expected = repos.tables().get_by_id(table.id).await.unwrap().unwrap(); + assert_eq!(table, expected); + + // compare columns + let mut expected = repos.columns().list_by_table_id(table.id).await.unwrap(); + expected.sort_unstable_by_key(|x| x.id); + let mut actual = snapshot.columns().collect::, _>>().unwrap(); + actual.sort_unstable_by_key(|x| x.id); + assert_eq!(expected, actual); + + // compare partitions + let mut expected = repos.partitions().list_by_table_id(table.id).await.unwrap(); + expected.sort_unstable_by_key(|x| x.id); + let mut actual = snapshot + .partitions() + .collect::, _>>() + .unwrap(); + actual.sort_unstable_by_key(|x| x.id()); + assert_eq!(expected.len(), actual.len()); + + let eq = expected + .iter() + .zip(&actual) + .all(|(l, r)| l.id == r.id() && l.partition_key.as_bytes() == r.key()); + assert!(eq, "expected {expected:?} got {actual:?}"); +} + +/// List all parquet files in given namespace. 
+async fn list_parquet_files_by_namespace_not_to_delete( + catalog: Arc, + namespace_id: NamespaceId, +) -> Vec { + let partitions = futures::stream::iter( + catalog + .repositories() + .tables() + .list_by_namespace_id(namespace_id) + .await + .unwrap(), + ) + .then(|t| { + let catalog = Arc::clone(&catalog); + async move { + futures::stream::iter( + catalog + .repositories() + .partitions() + .list_by_table_id(t.id) + .await + .unwrap(), + ) + } + }) + .flatten() + .map(|p| p.id) + .collect::>() + .await; + + catalog + .repositories() + .parquet_files() + .list_by_partition_not_to_delete_batch(partitions) + .await + .unwrap() +} + +/// tests many interactions with the catalog and parquet files. See the individual conditions +/// herein +async fn test_parquet_file(catalog: Arc) { + let mut repos = catalog.repositories(); + let namespace = arbitrary_namespace(&mut *repos, "namespace_parquet_file_test").await; + let table = arbitrary_table(&mut *repos, "test_table", &namespace).await; + let other_table = arbitrary_table(&mut *repos, "other", &namespace).await; + let partition = repos + .partitions() + .create_or_get("one".into(), table.id) + .await + .unwrap(); + let other_partition = repos + .partitions() + .create_or_get("one".into(), other_table.id) + .await + .unwrap(); + + let ts1 = repos.tables().snapshot(table.id).await.unwrap(); + validate_table_snapshot(repos.as_mut(), &ts1).await; + + let ts2 = repos.tables().snapshot(other_table.id).await.unwrap(); + validate_table_snapshot(repos.as_mut(), &ts2).await; + + let parquet_file_params = arbitrary_parquet_file_params(&namespace, &table, &partition); + let parquet_file = repos + .parquet_files() + .create(parquet_file_params.clone()) + .await + .unwrap(); + + // verify we can get it by its object store id + let pfg = repos + .parquet_files() + .get_by_object_store_id(parquet_file.object_store_id) + .await + .unwrap(); + assert_eq!(parquet_file, pfg.unwrap()); + + // verify that trying to create a file with the 
same UUID throws an error + let err = repos + .parquet_files() + .create(parquet_file_params.clone()) + .await + .unwrap_err(); + assert!(matches!(err, Error::AlreadyExists { .. })); + + let other_params = ParquetFileParams { + table_id: other_partition.table_id, + partition_id: other_partition.id, + partition_hash_id: other_partition.hash_id().cloned(), + object_store_id: ObjectStoreId::new(), + min_time: Timestamp::new(50), + max_time: Timestamp::new(60), + ..parquet_file_params.clone() + }; + let other_file = repos.parquet_files().create(other_params).await.unwrap(); + + let exist_id = parquet_file.id; + let non_exist_id = ParquetFileId::new(other_file.id.get() + 10); + // make sure exists_id != non_exist_id + assert_ne!(exist_id, non_exist_id); + + // verify that to_delete is initially set to null and the file does not get deleted + assert!(parquet_file.to_delete.is_none()); + let older_than = Timestamp::new( + (catalog.time_provider().now() + Duration::from_secs(100)).timestamp_nanos(), + ); + let deleted = repos + .parquet_files() + .delete_old_ids_only(older_than) + .await + .unwrap(); + assert!(deleted.is_empty()); + + // test list_all that includes soft-deleted file + // at this time the file is not soft-deleted yet and will be included in the returned list + let files = + list_parquet_files_by_namespace_not_to_delete(Arc::clone(&catalog), namespace.id).await; + assert_eq!(files.len(), 2); + + // verify to_delete can be updated to a timestamp + repos + .parquet_files() + .create_upgrade_delete( + parquet_file.partition_id, + &[parquet_file.object_store_id], + &[], + &[], + CompactionLevel::Initial, + ) + .await + .unwrap(); + + // test list_all that includes soft-deleted file + // at this time the file is soft-deleted and will be NOT included in the returned list + let files = + list_parquet_files_by_namespace_not_to_delete(Arc::clone(&catalog), namespace.id).await; + assert_eq!(files.len(), 1); + + // the deleted file can still be retrieved by UUID though 
+ repos + .parquet_files() + .get_by_object_store_id(parquet_file.object_store_id) + .await + .unwrap() + .unwrap(); + + // File is not deleted if it was marked to be deleted after the specified time + let before_deleted = Timestamp::new( + (catalog.time_provider().now() - Duration::from_secs(100)).timestamp_nanos(), + ); + let deleted = repos + .parquet_files() + .delete_old_ids_only(before_deleted) + .await + .unwrap(); + assert!(deleted.is_empty()); + + // not hard-deleted yet + repos + .parquet_files() + .get_by_object_store_id(parquet_file.object_store_id) + .await + .unwrap() + .unwrap(); + + // File is deleted if it was marked to be deleted before the specified time + let deleted = repos + .parquet_files() + .delete_old_ids_only(older_than) + .await + .unwrap(); + assert_eq!(deleted.len(), 1); + assert_eq!(parquet_file.object_store_id, deleted[0]); + + // test list_all that includes soft-deleted file + // at this time the file is hard deleted -> the returned list is empty + assert!(repos + .parquet_files() + .get_by_object_store_id(parquet_file.object_store_id) + .await + .unwrap() + .is_none()); + + // test list + let files = + list_parquet_files_by_namespace_not_to_delete(Arc::clone(&catalog), namespace.id).await; + assert_eq!(vec![other_file.clone()], files); + + // test list_by_namespace_not_to_delete + let namespace2 = arbitrary_namespace(&mut *repos, "namespace_parquet_file_test1").await; + let table2 = arbitrary_table(&mut *repos, "test_table2", &namespace2).await; + let partition2 = repos + .partitions() + .create_or_get("foo".into(), table2.id) + .await + .unwrap(); + let files = + list_parquet_files_by_namespace_not_to_delete(Arc::clone(&catalog), namespace2.id).await; + assert!(files.is_empty()); + + let ts3 = repos.tables().snapshot(table2.id).await.unwrap(); + validate_table_snapshot(repos.as_mut(), &ts3).await; + + let f1_params = ParquetFileParams { + table_id: partition2.table_id, + partition_id: partition2.id, + partition_hash_id: 
partition2.hash_id().cloned(), + namespace_id: namespace2.id, + object_store_id: ObjectStoreId::new(), + min_time: Timestamp::new(1), + max_time: Timestamp::new(10), + ..parquet_file_params + }; + let f1 = repos + .parquet_files() + .create(f1_params.clone()) + .await + .unwrap(); + + let f2_params = ParquetFileParams { + object_store_id: ObjectStoreId::new(), + min_time: Timestamp::new(50), + max_time: Timestamp::new(60), + ..f1_params.clone() + }; + let f2 = repos + .parquet_files() + .create(f2_params.clone()) + .await + .unwrap(); + let files = + list_parquet_files_by_namespace_not_to_delete(Arc::clone(&catalog), namespace2.id).await; + assert_eq!(vec![f1.clone(), f2.clone()], files); + + let f3_params = ParquetFileParams { + object_store_id: ObjectStoreId::new(), + min_time: Timestamp::new(50), + max_time: Timestamp::new(60), + ..f2_params + }; + let f3 = repos + .parquet_files() + .create(f3_params.clone()) + .await + .unwrap(); + let files = + list_parquet_files_by_namespace_not_to_delete(Arc::clone(&catalog), namespace2.id).await; + assert_eq!(vec![f1.clone(), f2.clone(), f3.clone()], files); + + let s1 = repos.partitions().snapshot(partition2.id).await.unwrap(); + validate_partition_snapshot(repos.as_mut(), &s1).await; + + repos + .parquet_files() + .create_upgrade_delete( + f2.partition_id, + &[f2.object_store_id], + &[], + &[], + CompactionLevel::Initial, + ) + .await + .unwrap(); + let files = + list_parquet_files_by_namespace_not_to_delete(Arc::clone(&catalog), namespace2.id).await; + assert_eq!(vec![f1.clone(), f3.clone()], files); + + // Cannot delete file twice + let err = repos + .parquet_files() + .create_upgrade_delete( + partition2.id, + &[f2.object_store_id, f3.object_store_id], + &[], + &[], + CompactionLevel::Initial, + ) + .await + .unwrap_err(); + assert_matches!(err, Error::NotFound { .. 
}); + + let err = repos + .parquet_files() + .create_upgrade_delete( + partition2.id, + &[f2.object_store_id], + &[f3.object_store_id], + &[], + CompactionLevel::Initial, + ) + .await + .unwrap_err(); + assert_matches!(err, Error::NotFound { .. }); + + // Cannot upgrade deleted file + let err = repos + .parquet_files() + .create_upgrade_delete( + partition2.id, + &[f3.object_store_id], + &[f2.object_store_id], + &[], + CompactionLevel::Initial, + ) + .await + .unwrap_err(); + assert_matches!(err, Error::NotFound { .. }); + + // Failed transactions don't modify + let files = + list_parquet_files_by_namespace_not_to_delete(Arc::clone(&catalog), namespace2.id).await; + assert_eq!(vec![f1.clone(), f3.clone()], files); + + let s2 = repos.partitions().snapshot(partition2.id).await.unwrap(); + assert_gt(s2.generation(), s1.generation()); + validate_partition_snapshot(repos.as_mut(), &s2).await; + + let files = list_parquet_files_by_namespace_not_to_delete( + Arc::clone(&catalog), + NamespaceId::new(i64::MAX), + ) + .await; + assert!(files.is_empty()); + + // test delete_old_ids_only + let older_than = Timestamp::new( + (catalog.time_provider().now() + Duration::from_secs(100)).timestamp_nanos(), + ); + let ids = repos + .parquet_files() + .delete_old_ids_only(older_than) + .await + .unwrap(); + assert_eq!(ids.len(), 1); + + let s3 = repos.partitions().snapshot(partition2.id).await.unwrap(); + assert_ge(s3.generation(), s2.generation()); // no new snapshot required, but some backends will generate a new one + validate_partition_snapshot(repos.as_mut(), &s3).await; + + // test retention-based flagging for deletion + // Since mem catalog has default retention 1 hour, let us first set it to 0 means infinite + let namespaces = repos + .namespaces() + .list(SoftDeletedRows::AllRows) + .await + .expect("listing namespaces"); + for namespace in namespaces { + repos + .namespaces() + .update_retention_period(&namespace.name, None) // infinite + .await + .unwrap(); + } + + // 1. 
with no retention period set on the ns, nothing should get flagged + let ids = repos + .parquet_files() + .flag_for_delete_by_retention() + .await + .unwrap(); + assert!(ids.is_empty()); + // 2. set ns retention period to one hour then create some files before and after and + // ensure correct files get deleted + repos + .namespaces() + .update_retention_period(&namespace2.name, Some(60 * 60 * 1_000_000_000)) // 1 hour + .await + .unwrap(); + let f4_params = ParquetFileParams { + object_store_id: ObjectStoreId::new(), + max_time: Timestamp::new( + // a bit over an hour ago + (catalog.time_provider().now() - Duration::from_secs(60 * 65)).timestamp_nanos(), + ), + ..f3_params + }; + let f4 = repos + .parquet_files() + .create(f4_params.clone()) + .await + .unwrap(); + let f5_params = ParquetFileParams { + object_store_id: ObjectStoreId::new(), + max_time: Timestamp::new( + // a bit under an hour ago + (catalog.time_provider().now() - Duration::from_secs(60 * 55)).timestamp_nanos(), + ), + ..f4_params + }; + let f5 = repos + .parquet_files() + .create(f5_params.clone()) + .await + .unwrap(); + let ids = repos + .parquet_files() + .flag_for_delete_by_retention() + .await + .unwrap(); + assert!(ids.len() > 1); // it's also going to flag f1, f2 & f3 because they have low max + // timestamps but i don't want this test to be brittle if those + // values change so i'm not asserting len == 4 + let f4 = repos + .parquet_files() + .get_by_object_store_id(f4.object_store_id) + .await + .unwrap() + .unwrap(); + assert_matches!(f4.to_delete, Some(_)); // f4 is > 1hr old + let f5 = repos + .parquet_files() + .get_by_object_store_id(f5.object_store_id) + .await + .unwrap() + .unwrap(); + assert_matches!(f5.to_delete, None); // f5 is < 1hr old + + let s4 = repos.partitions().snapshot(partition2.id).await.unwrap(); + assert_gt(s4.generation(), s3.generation()); + validate_partition_snapshot(repos.as_mut(), &s4).await; + + // call flag_for_delete_by_retention() again and nothing 
should be flagged because they've + // already been flagged + let ids = repos + .parquet_files() + .flag_for_delete_by_retention() + .await + .unwrap(); + assert!(ids.is_empty()); + + // test that flag_for_delete_by_retention respects UPDATE LIMIT + // create limit + the meaning of life parquet files that are all older than the retention (>1hr) + const LIMIT: usize = 1000; + const MOL: usize = 42; + let now = catalog.time_provider().now(); + let params = (0..LIMIT + MOL) + .map(|_| { + ParquetFileParams { + object_store_id: ObjectStoreId::new(), + max_time: Timestamp::new( + // a bit over an hour ago + (now - Duration::from_secs(60 * 65)).timestamp_nanos(), + ), + ..f1_params.clone() + } + }) + .collect::>(); + repos + .parquet_files() + .create_upgrade_delete( + f1_params.partition_id, + &[], + &[], + ¶ms, + CompactionLevel::Initial, + ) + .await + .unwrap(); + let ids = repos + .parquet_files() + .flag_for_delete_by_retention() + .await + .unwrap(); + assert_eq!(ids.len(), LIMIT); + let ids = repos + .parquet_files() + .flag_for_delete_by_retention() + .await + .unwrap(); + assert_eq!(ids.len(), MOL); // second call took remainder + let ids = repos + .parquet_files() + .flag_for_delete_by_retention() + .await + .unwrap(); + assert_eq!(ids.len(), 0); // none left + + // test create_update_delete + let f6_params = ParquetFileParams { + object_store_id: ObjectStoreId::new(), + ..f5_params + }; + let f6 = repos + .parquet_files() + .create(f6_params.clone()) + .await + .unwrap(); + + let f7_params = ParquetFileParams { + object_store_id: ObjectStoreId::new(), + ..f6_params + }; + let f1_uuid = f1.object_store_id; + let f6_uuid = f6.object_store_id; + let f5_uuid = f5.object_store_id; + let cud = repos + .parquet_files() + .create_upgrade_delete( + f5.partition_id, + &[f5.object_store_id], + &[f6.object_store_id], + &[f7_params.clone()], + CompactionLevel::Final, + ) + .await + .unwrap(); + + assert_eq!(cud.len(), 1); + let f5_delete = repos + .parquet_files() + 
.get_by_object_store_id(f5_uuid) + .await + .unwrap() + .unwrap(); + assert_matches!(f5_delete.to_delete, Some(_)); + + let f6_compaction_level = repos + .parquet_files() + .get_by_object_store_id(f6_uuid) + .await + .unwrap() + .unwrap(); + + assert_matches!(f6_compaction_level.compaction_level, CompactionLevel::Final); + + let f7 = repos + .parquet_files() + .get_by_object_store_id(f7_params.object_store_id) + .await + .unwrap() + .unwrap(); + + let f7_uuid = f7.object_store_id; + + // test create_update_delete transaction (rollback because f7 already exists) + let cud = repos + .parquet_files() + .create_upgrade_delete( + partition2.id, + &[], + &[], + &[f7_params.clone()], + CompactionLevel::Final, + ) + .await; + + assert_matches!( + cud, + Err(Error::AlreadyExists { + descr + }) if descr == f7_params.object_store_id.to_string() + ); + + let f1_to_delete = repos + .parquet_files() + .get_by_object_store_id(f1_uuid) + .await + .unwrap() + .unwrap(); + assert_matches!(f1_to_delete.to_delete, Some(_)); + + let f7_not_delete = repos + .parquet_files() + .get_by_object_store_id(f7_uuid) + .await + .unwrap() + .unwrap(); + assert_matches!(f7_not_delete.to_delete, None); + + // test exists_by_object_store_id_batch returns parquet files by object store id + let does_not_exist = ObjectStoreId::new(); + let mut present = repos + .parquet_files() + .exists_by_object_store_id_batch(vec![f1_uuid, f7_uuid, does_not_exist]) + .await + .unwrap(); + let mut expected = vec![f1_uuid, f7_uuid]; + present.sort(); + expected.sort(); + assert_eq!(present, expected); + + let s5 = repos.partitions().snapshot(partition2.id).await.unwrap(); + assert_gt(s5.generation(), s4.generation()); + validate_partition_snapshot(repos.as_mut(), &s5).await; + + // Cannot mix partition IDs + let partition3 = repos + .partitions() + .create_or_get("three".into(), table.id) + .await + .unwrap(); + + let ts4 = repos.tables().snapshot(table.id).await.unwrap(); + validate_table_snapshot(repos.as_mut(), 
&ts4).await; + assert_gt(ts4.generation(), ts1.generation()); + + let f8_params = ParquetFileParams { + object_store_id: ObjectStoreId::new(), + partition_id: partition3.id, + ..f7_params + }; + let err = repos + .parquet_files() + .create_upgrade_delete( + partition2.id, + &[f7_uuid], + &[], + &[f8_params.clone()], + CompactionLevel::Final, + ) + .await + .unwrap_err() + .to_string(); + + assert!( + err.contains("Inconsistent ParquetFileParams, expected PartitionId"), + "{err}" + ); + + let list = repos + .parquet_files() + .list_by_partition_not_to_delete_batch(vec![partition2.id]) + .await + .unwrap(); + assert_eq!(list.len(), 2); + + repos + .parquet_files() + .create_upgrade_delete( + partition3.id, + &[], + &[], + &[f8_params.clone()], + CompactionLevel::Final, + ) + .await + .unwrap(); + + let files = repos + .parquet_files() + .list_by_partition_not_to_delete_batch(vec![partition3.id]) + .await + .unwrap(); + assert_eq!(files.len(), 1); + let f8_uuid = files[0].object_store_id; + + let files = repos + .parquet_files() + .list_by_partition_not_to_delete_batch(vec![]) + .await + .unwrap(); + assert_eq!(files.len(), 0); + let files = repos + .parquet_files() + .list_by_partition_not_to_delete_batch(vec![partition2.id, partition3.id]) + .await + .unwrap(); + assert_eq!(files.len(), 3); + let files = repos + .parquet_files() + .list_by_partition_not_to_delete_batch(vec![ + partition2.id, + PartitionId::new(i64::MAX), + partition3.id, + partition2.id, + ]) + .await + .unwrap(); + assert_eq!(files.len(), 3); + + let err = repos + .parquet_files() + .create_upgrade_delete(partition2.id, &[f8_uuid], &[], &[], CompactionLevel::Final) + .await + .unwrap_err(); + + assert_matches!(err, Error::NotFound { .. }); + + let err = repos + .parquet_files() + .create_upgrade_delete(partition2.id, &[], &[f8_uuid], &[], CompactionLevel::Final) + .await + .unwrap_err(); + + assert_matches!(err, Error::NotFound { .. 
}); + + repos + .parquet_files() + .create_upgrade_delete(partition3.id, &[f8_uuid], &[], &[], CompactionLevel::Final) + .await + .unwrap(); + + // take snapshot of unknown partition + let err = repos + .partitions() + .snapshot(PartitionId::new(i64::MAX)) + .await + .unwrap_err(); + assert_matches!(err, Error::NotFound { .. }); +} + +async fn test_parquet_file_delete_broken(catalog: Arc) { + let mut repos = catalog.repositories(); + let namespace_1 = arbitrary_namespace(&mut *repos, "retention_broken_1").await; + let namespace_2 = repos + .namespaces() + .create( + &NamespaceName::new("retention_broken_2").unwrap(), + None, + Some(1), + None, + ) + .await + .unwrap(); + let table_1 = arbitrary_table(&mut *repos, "test_table", &namespace_1).await; + let table_2 = arbitrary_table(&mut *repos, "test_table", &namespace_2).await; + let partition_1 = repos + .partitions() + .create_or_get("one".into(), table_1.id) + .await + .unwrap(); + let partition_2 = repos + .partitions() + .create_or_get("one".into(), table_2.id) + .await + .unwrap(); + + let parquet_file_params_1 = arbitrary_parquet_file_params(&namespace_1, &table_1, &partition_1); + let parquet_file_params_2 = arbitrary_parquet_file_params(&namespace_2, &table_2, &partition_2); + let _parquet_file_1 = repos + .parquet_files() + .create(parquet_file_params_1) + .await + .unwrap(); + let parquet_file_2 = repos + .parquet_files() + .create(parquet_file_params_2) + .await + .unwrap(); + + let ids = repos + .parquet_files() + .flag_for_delete_by_retention() + .await + .unwrap(); + assert_eq!( + ids, + vec![(parquet_file_2.partition_id, parquet_file_2.object_store_id)] + ); +} + +async fn test_partitions_new_file_between(catalog: Arc) { + let mut repos = catalog.repositories(); + let namespace = arbitrary_namespace(&mut *repos, "test_partitions_new_file_between").await; + let table = arbitrary_table(&mut *repos, "test_table_for_new_file_between", &namespace).await; + + // param for the tests + let time_now = 
Timestamp::from(catalog.time_provider().now()); + let time_one_hour_ago = Timestamp::from(catalog.time_provider().hours_ago(1)); + let time_two_hour_ago = Timestamp::from(catalog.time_provider().hours_ago(2)); + let time_three_hour_ago = Timestamp::from(catalog.time_provider().hours_ago(3)); + let time_five_hour_ago = Timestamp::from(catalog.time_provider().hours_ago(5)); + let time_six_hour_ago = Timestamp::from(catalog.time_provider().hours_ago(6)); + + // Db has no partitions + let partitions = repos + .partitions() + .partitions_new_file_between(time_two_hour_ago, None) + .await + .unwrap(); + assert!(partitions.is_empty()); + + // ----------------- + // PARTITION one + // The DB has 1 partition but it does not have any file + let partition1 = repos + .partitions() + .create_or_get("one".into(), table.id) + .await + .unwrap(); + let partitions = repos + .partitions() + .partitions_new_file_between(time_two_hour_ago, None) + .await + .unwrap(); + assert!(partitions.is_empty()); + + // create files for partition one + let parquet_file_params = arbitrary_parquet_file_params(&namespace, &table, &partition1); + + // create a deleted L0 file that was created 3 hours ago + let delete_l0_file = repos + .parquet_files() + .create(parquet_file_params.clone()) + .await + .unwrap(); + repos + .parquet_files() + .create_upgrade_delete( + delete_l0_file.partition_id, + &[delete_l0_file.object_store_id], + &[], + &[], + CompactionLevel::Initial, + ) + .await + .unwrap(); + let partitions = repos + .partitions() + .partitions_new_file_between(time_two_hour_ago, None) + .await + .unwrap(); + assert!(partitions.is_empty()); + let partitions = repos + .partitions() + .partitions_new_file_between(time_two_hour_ago, Some(time_one_hour_ago)) + .await + .unwrap(); + assert!(partitions.is_empty()); + let partitions = repos + .partitions() + .partitions_new_file_between(time_three_hour_ago, Some(time_one_hour_ago)) + .await + .unwrap(); + assert!(partitions.is_empty()); + + // create a 
deleted L0 file that was created 1 hour ago + let l0_one_hour_ago_file_params = ParquetFileParams { + object_store_id: ObjectStoreId::new(), + created_at: time_one_hour_ago, + ..parquet_file_params.clone() + }; + repos + .parquet_files() + .create(l0_one_hour_ago_file_params.clone()) + .await + .unwrap(); + // partition one should be returned + let partitions = repos + .partitions() + .partitions_new_file_between(time_two_hour_ago, None) + .await + .unwrap(); + assert_eq!(partitions.len(), 1); + assert_eq!(partitions[0], partition1.id); + let partitions = repos + .partitions() + .partitions_new_file_between(time_two_hour_ago, Some(time_now)) + .await + .unwrap(); + assert_eq!(partitions.len(), 1); + assert_eq!(partitions[0], partition1.id); + let partitions = repos + .partitions() + .partitions_new_file_between(time_three_hour_ago, Some(time_now)) + .await + .unwrap(); + assert_eq!(partitions.len(), 1); + assert_eq!(partitions[0], partition1.id); + let partitions = repos + .partitions() + .partitions_new_file_between(time_three_hour_ago, Some(time_two_hour_ago)) + .await + .unwrap(); + assert!(partitions.is_empty()); + + // ----------------- + // PARTITION two + // Partition two without any file + let partition2 = repos + .partitions() + .create_or_get("two".into(), table.id) + .await + .unwrap(); + // should return partition one only + let partitions = repos + .partitions() + .partitions_new_file_between(time_two_hour_ago, None) + .await + .unwrap(); + assert_eq!(partitions.len(), 1); + assert_eq!(partitions[0], partition1.id); + let partitions = repos + .partitions() + .partitions_new_file_between(time_three_hour_ago, Some(time_now)) + .await + .unwrap(); + assert_eq!(partitions.len(), 1); + assert_eq!(partitions[0], partition1.id); + + // Add a L0 file created 5 hours ago + let l0_five_hour_ago_file_params = ParquetFileParams { + object_store_id: ObjectStoreId::new(), + created_at: time_five_hour_ago, + partition_id: partition2.id, + partition_hash_id: 
partition2.hash_id().cloned(), + ..parquet_file_params.clone() + }; + repos + .parquet_files() + .create(l0_five_hour_ago_file_params.clone()) + .await + .unwrap(); + // still return partition one only + let partitions = repos + .partitions() + .partitions_new_file_between(time_two_hour_ago, None) + .await + .unwrap(); + assert_eq!(partitions.len(), 1); + assert_eq!(partitions[0], partition1.id); + let partitions = repos + .partitions() + .partitions_new_file_between(time_three_hour_ago, Some(time_now)) + .await + .unwrap(); + assert_eq!(partitions.len(), 1); + assert_eq!(partitions[0], partition1.id); + let partitions = repos + .partitions() + .partitions_new_file_between(time_three_hour_ago, Some(time_now)) + .await + .unwrap(); + assert_eq!(partitions.len(), 1); + assert_eq!(partitions[0], partition1.id); + // Between six and three hours ago, return only partition 2 + let partitions = repos + .partitions() + .partitions_new_file_between(time_six_hour_ago, Some(time_three_hour_ago)) + .await + .unwrap(); + assert_eq!(partitions.len(), 1); + assert_eq!(partitions[0], partition2.id); + + // Add an L1 file created just now + let l1_file_params = ParquetFileParams { + object_store_id: ObjectStoreId::new(), + created_at: time_now, + partition_id: partition2.id, + partition_hash_id: partition2.hash_id().cloned(), + compaction_level: CompactionLevel::FileNonOverlapped, + ..parquet_file_params.clone() + }; + repos + .parquet_files() + .create(l1_file_params.clone()) + .await + .unwrap(); + // should return both partitions + let mut partitions = repos + .partitions() + .partitions_new_file_between(time_two_hour_ago, None) + .await + .unwrap(); + assert_eq!(partitions.len(), 2); + partitions.sort(); + assert_eq!(partitions[0], partition1.id); + assert_eq!(partitions[1], partition2.id); + // Only return partition1: the creation time must be strictly less than the maximum time, + // not equal + let mut partitions = repos + .partitions() + 
.partitions_new_file_between(time_three_hour_ago, Some(time_now)) + .await + .unwrap(); + assert_eq!(partitions.len(), 1); + partitions.sort(); + assert_eq!(partitions[0], partition1.id); + // Between six and three hours ago, return none + let partitions = repos + .partitions() + .partitions_new_file_between(time_six_hour_ago, Some(time_three_hour_ago)) + .await + .unwrap(); + assert!(partitions.is_empty()); + + // ----------------- + // PARTITION three + // Partition three without any file + let partition3 = repos + .partitions() + .create_or_get("three".into(), table.id) + .await + .unwrap(); + // should return partition one and two only + let mut partitions = repos + .partitions() + .partitions_new_file_between(time_two_hour_ago, None) + .await + .unwrap(); + assert_eq!(partitions.len(), 2); + partitions.sort(); + assert_eq!(partitions[0], partition1.id); + assert_eq!(partitions[1], partition2.id); + // Only return partition1: the creation time must be strictly less than the maximum time, + // not equal + let partitions = repos + .partitions() + .partitions_new_file_between(time_three_hour_ago, Some(time_now)) + .await + .unwrap(); + assert_eq!(partitions.len(), 1); + assert_eq!(partitions[0], partition1.id); + // When the maximum time is greater than the creation time of partition2, return it + let mut partitions = repos + .partitions() + .partitions_new_file_between(time_three_hour_ago, Some(time_now + 1)) + .await + .unwrap(); + assert_eq!(partitions.len(), 2); + partitions.sort(); + assert_eq!(partitions[0], partition1.id); + assert_eq!(partitions[1], partition2.id); + // Between six and three hours ago, return none + let partitions = repos + .partitions() + .partitions_new_file_between(time_six_hour_ago, Some(time_three_hour_ago)) + .await + .unwrap(); + assert!(partitions.is_empty()); + + // Add an L2 file created just now for partition three + let l2_file_params = ParquetFileParams { + object_store_id: ObjectStoreId::new(), + created_at: time_now, + 
partition_id: partition3.id, + partition_hash_id: partition3.hash_id().cloned(), + compaction_level: CompactionLevel::Final, + ..parquet_file_params.clone() + }; + repos + .parquet_files() + .create(l2_file_params.clone()) + .await + .unwrap(); + // now should return partition one two and three + let mut partitions = repos + .partitions() + .partitions_new_file_between(time_two_hour_ago, None) + .await + .unwrap(); + assert_eq!(partitions.len(), 3); + partitions.sort(); + assert_eq!(partitions[0], partition1.id); + assert_eq!(partitions[1], partition2.id); + assert_eq!(partitions[2], partition3.id); + // Only return partition1: the creation time must be strictly less than the maximum time, + // not equal + let partitions = repos + .partitions() + .partitions_new_file_between(time_three_hour_ago, Some(time_now)) + .await + .unwrap(); + assert_eq!(partitions.len(), 1); + assert_eq!(partitions[0], partition1.id); + // Between six and three hours ago, return none + let partitions = repos + .partitions() + .partitions_new_file_between(time_six_hour_ago, Some(time_three_hour_ago)) + .await + .unwrap(); + assert!(partitions.is_empty()); + + // add an L0 file created one hour ago for partition three + let l0_one_hour_ago_file_params = ParquetFileParams { + object_store_id: ObjectStoreId::new(), + created_at: time_one_hour_ago, + partition_id: partition3.id, + partition_hash_id: partition3.hash_id().cloned(), + ..parquet_file_params.clone() + }; + repos + .parquet_files() + .create(l0_one_hour_ago_file_params.clone()) + .await + .unwrap(); + // should return all partitions + let mut partitions = repos + .partitions() + .partitions_new_file_between(time_two_hour_ago, None) + .await + .unwrap(); + assert_eq!(partitions.len(), 3); + partitions.sort(); + assert_eq!(partitions[0], partition1.id); + assert_eq!(partitions[1], partition2.id); + assert_eq!(partitions[2], partition3.id); + // Only return partitions 1 and 3; 2 was created just now + let mut partitions = repos + 
.partitions() + .partitions_new_file_between(time_three_hour_ago, Some(time_now)) + .await + .unwrap(); + assert_eq!(partitions.len(), 2); + partitions.sort(); + assert_eq!(partitions[0], partition1.id); + assert_eq!(partitions[1], partition3.id); + // Between six and three hours ago, return none + let partitions = repos + .partitions() + .partitions_new_file_between(time_six_hour_ago, Some(time_three_hour_ago)) + .await + .unwrap(); + assert!(partitions.is_empty()); +} + +async fn test_list_by_partiton_not_to_delete(catalog: Arc) { + let mut repos = catalog.repositories(); + let namespace = arbitrary_namespace( + &mut *repos, + "namespace_parquet_file_test_list_by_partiton_not_to_delete", + ) + .await; + let table = arbitrary_table(&mut *repos, "test_table", &namespace).await; + + let partition = repos + .partitions() + .create_or_get("test_list_by_partiton_not_to_delete_one".into(), table.id) + .await + .unwrap(); + let partition2 = repos + .partitions() + .create_or_get("test_list_by_partiton_not_to_delete_two".into(), table.id) + .await + .unwrap(); + + let parquet_file_params = arbitrary_parquet_file_params(&namespace, &table, &partition); + + let parquet_file = repos + .parquet_files() + .create(parquet_file_params.clone()) + .await + .unwrap(); + let delete_file_params = ParquetFileParams { + object_store_id: ObjectStoreId::new(), + ..parquet_file_params.clone() + }; + let delete_file = repos + .parquet_files() + .create(delete_file_params) + .await + .unwrap(); + repos + .parquet_files() + .create_upgrade_delete( + partition.id, + &[delete_file.object_store_id], + &[], + &[], + CompactionLevel::Initial, + ) + .await + .unwrap(); + let level1_file_params = ParquetFileParams { + object_store_id: ObjectStoreId::new(), + ..parquet_file_params.clone() + }; + let mut level1_file = repos + .parquet_files() + .create(level1_file_params) + .await + .unwrap(); + repos + .parquet_files() + .create_upgrade_delete( + partition.id, + &[], + 
&[level1_file.object_store_id], + &[], + CompactionLevel::FileNonOverlapped, + ) + .await + .unwrap(); + level1_file.compaction_level = CompactionLevel::FileNonOverlapped; + + let other_partition_params = ParquetFileParams { + partition_id: partition2.id, + partition_hash_id: partition2.hash_id().cloned(), + object_store_id: ObjectStoreId::new(), + ..parquet_file_params.clone() + }; + let _partition2_file = repos + .parquet_files() + .create(other_partition_params) + .await + .unwrap(); + + let files = repos + .parquet_files() + .list_by_partition_not_to_delete_batch(vec![partition.id]) + .await + .unwrap(); + assert_eq!(files.len(), 2); + + let mut file_ids: Vec<_> = files.into_iter().map(|f| f.id).collect(); + file_ids.sort(); + let mut expected_ids = vec![parquet_file.id, level1_file.id]; + expected_ids.sort(); + assert_eq!(file_ids, expected_ids); + + // Using the catalog partition ID should return the same files, even if the Parquet file + // records don't have the partition ID on them (which is the default now) + let files = repos + .parquet_files() + .list_by_partition_not_to_delete_batch(vec![partition.id]) + .await + .unwrap(); + assert_eq!(files.len(), 2); + + let mut file_ids: Vec<_> = files.into_iter().map(|f| f.id).collect(); + file_ids.sort(); + let mut expected_ids = vec![parquet_file.id, level1_file.id]; + expected_ids.sort(); + assert_eq!(file_ids, expected_ids); +} + +async fn test_update_to_compaction_level_1(catalog: Arc) { + let mut repos = catalog.repositories(); + let namespace = + arbitrary_namespace(&mut *repos, "namespace_update_to_compaction_level_1_test").await; + let table = arbitrary_table(&mut *repos, "update_table", &namespace).await; + let partition = repos + .partitions() + .create_or_get("test_update_to_compaction_level_1_one".into(), table.id) + .await + .unwrap(); + + // Set up the window of times we're interested in level 1 files for + let query_min_time = Timestamp::new(5); + let query_max_time = Timestamp::new(10); + + // 
Create a file with times entirely within the window + let mut parquet_file_params = arbitrary_parquet_file_params(&namespace, &table, &partition); + parquet_file_params.min_time = query_min_time + 1; + parquet_file_params.max_time = query_max_time - 1; + let parquet_file = repos + .parquet_files() + .create(parquet_file_params.clone()) + .await + .unwrap(); + + // Create a file that will remain as level 0 + let level_0_params = ParquetFileParams { + object_store_id: ObjectStoreId::new(), + ..parquet_file_params.clone() + }; + repos.parquet_files().create(level_0_params).await.unwrap(); + + // Make parquet_file compaction level 1 + let created = repos + .parquet_files() + .create_upgrade_delete( + parquet_file.partition_id, + &[], + &[parquet_file.object_store_id], + &[], + CompactionLevel::FileNonOverlapped, + ) + .await + .unwrap(); + assert_eq!(created, vec![]); + + // remove namespace to avoid it from affecting later tests + repos + .namespaces() + .soft_delete("namespace_update_to_compaction_level_1_test") + .await + .expect("delete namespace should succeed"); +} + +/// Assert that a namespace deletion does NOT cascade to the tables/schema +/// items/parquet files/etc. +/// +/// Removal of this entities breaks the invariant that once created, a row +/// always exists for the lifetime of an IOx process, and causes the system +/// to panic in multiple components. It's also ineffective, because most +/// components maintain a cache of at least one of these entities. +/// +/// Instead soft deleted namespaces should have their files GC'd like a +/// normal parquet file deletion, removing the rows once they're no longer +/// being actively used by the system. This is done by waiting a long time +/// before deleting records, and whilst isn't perfect, it is largely +/// effective. 
+async fn test_delete_namespace(catalog: Arc) { + let mut repos = catalog.repositories(); + let namespace_1 = arbitrary_namespace(&mut *repos, "namespace_test_delete_namespace_1").await; + let table_1 = arbitrary_table(&mut *repos, "test_table_1", &namespace_1).await; + let _c = repos + .columns() + .create_or_get("column_test_1", table_1.id, ColumnType::Tag) + .await + .unwrap(); + let partition_1 = repos + .partitions() + .create_or_get("test_delete_namespace_one".into(), table_1.id) + .await + .unwrap(); + + // parquet files + let parquet_file_params = arbitrary_parquet_file_params(&namespace_1, &table_1, &partition_1); + repos + .parquet_files() + .create(parquet_file_params.clone()) + .await + .unwrap(); + let parquet_file_params_2 = ParquetFileParams { + object_store_id: ObjectStoreId::new(), + min_time: Timestamp::new(200), + max_time: Timestamp::new(300), + ..parquet_file_params + }; + repos + .parquet_files() + .create(parquet_file_params_2.clone()) + .await + .unwrap(); + + // we've now created a namespace with a table and parquet files. before we test deleting + // it, let's create another so we can ensure that doesn't get deleted. 
+ let namespace_2 = arbitrary_namespace(&mut *repos, "namespace_test_delete_namespace_2").await; + let table_2 = arbitrary_table(&mut *repos, "test_table_2", &namespace_2).await; + let _c = repos + .columns() + .create_or_get("column_test_2", table_2.id, ColumnType::Tag) + .await + .unwrap(); + let partition_2 = repos + .partitions() + .create_or_get("test_delete_namespace_two".into(), table_2.id) + .await + .unwrap(); + + // parquet files + let parquet_file_params = arbitrary_parquet_file_params(&namespace_2, &table_2, &partition_2); + repos + .parquet_files() + .create(parquet_file_params.clone()) + .await + .unwrap(); + let parquet_file_params_2 = ParquetFileParams { + object_store_id: ObjectStoreId::new(), + min_time: Timestamp::new(200), + max_time: Timestamp::new(300), + ..parquet_file_params + }; + repos + .parquet_files() + .create(parquet_file_params_2.clone()) + .await + .unwrap(); + + // now delete namespace_1 and assert it's all gone and none of + // namespace_2 is gone + repos + .namespaces() + .soft_delete("namespace_test_delete_namespace_1") + .await + .expect("delete namespace should succeed"); + // assert that namespace is soft-deleted, but the table, column, and parquet files are all + // still there. + assert!(repos + .namespaces() + .get_by_id(namespace_1.id, SoftDeletedRows::ExcludeDeleted) + .await + .expect("get namespace should succeed") + .is_none()); + assert_eq!( + repos + .namespaces() + .get_by_id(namespace_1.id, SoftDeletedRows::AllRows) + .await + .expect("get namespace should succeed") + .map(|mut v| { + // The only change after soft-deletion should be the deleted_at + // field being set - this block normalises that field, so that + // the before/after can be asserted as equal. 
+ v.deleted_at = None; + v + }) + .expect("should see soft-deleted row"), + namespace_1 + ); + assert_eq!( + repos + .tables() + .get_by_id(table_1.id) + .await + .expect("get table should succeed") + .expect("should return row"), + table_1 + ); + assert_eq!( + repos + .columns() + .list_by_namespace_id(namespace_1.id) + .await + .expect("listing columns should succeed") + .len(), + 1 + ); + assert_eq!( + repos + .columns() + .list_by_table_id(table_1.id) + .await + .expect("listing columns should succeed") + .len(), + 1 + ); + + // partition's get_by_id should succeed + repos + .partitions() + .get_by_id_batch(&[partition_1.id]) + .await + .unwrap() + .into_iter() + .next() + .unwrap(); + + // assert that the namespace, table, column, and parquet files for namespace_2 are still + // there + assert!(repos + .namespaces() + .get_by_id(namespace_2.id, SoftDeletedRows::ExcludeDeleted) + .await + .expect("get namespace should succeed") + .is_some()); + + assert!(repos + .tables() + .get_by_id(table_2.id) + .await + .expect("get table should succeed") + .is_some()); + assert_eq!( + repos + .columns() + .list_by_namespace_id(namespace_2.id) + .await + .expect("listing columns should succeed") + .len(), + 1 + ); + assert_eq!( + repos + .columns() + .list_by_table_id(table_2.id) + .await + .expect("listing columns should succeed") + .len(), + 1 + ); + + // partition's get_by_id should succeed + repos + .partitions() + .get_by_id_batch(&[partition_2.id]) + .await + .unwrap() + .into_iter() + .next() + .unwrap(); +} + +/// Upsert a namespace called `namespace_name` and write `lines` to it. +async fn populate_namespace( + repos: &mut R, + namespace_name: &str, + lines: &str, +) -> (Namespace, NamespaceSchema) +where + R: RepoCollection + ?Sized, +{ + let namespace = repos + .namespaces() + .create( + &NamespaceName::new(namespace_name).unwrap(), + None, + None, + None, + ) + .await; + + let namespace = match namespace { + Ok(v) => v, + Err(Error::AlreadyExists { .. 
}) => repos + .namespaces() + .get_by_name(namespace_name, SoftDeletedRows::AllRows) + .await + .unwrap() + .unwrap(), + e @ Err(_) => e.unwrap(), + }; + + let batches = mutable_batch_lp::lines_to_batches(lines, 42).unwrap(); + let batches = batches.iter().map(|(table, batch)| (table.as_str(), batch)); + let ns = NamespaceSchema::new_empty_from(&namespace); + + let schema = validate_or_insert_schema(batches, &ns, repos) + .await + .expect("validate schema failed") + .unwrap_or(ns); + + (namespace, schema) +} + +async fn test_list_schemas(catalog: Arc) { + let mut repos = catalog.repositories(); + + let ns1 = populate_namespace( + repos.deref_mut(), + "ns1", + "cpu,tag=1 field=1i\nanother,tag=1 field=1.0", + ) + .await; + let ns2 = populate_namespace( + repos.deref_mut(), + "ns2", + "cpu,tag=1 field=1i\nsomethingelse field=1u", + ) + .await; + + // Otherwise the in-mem catalog deadlocks.... (but not postgres) + drop(repos); + + let got = list_schemas(&*catalog) + .await + .expect("should be able to list the schemas") + .collect::>(); + + assert!(got.contains(&ns1), "{:#?}\n\nwant{:#?}", got, &ns1); + assert!(got.contains(&ns2), "{:#?}\n\nwant{:#?}", got, &ns2); +} + +async fn test_list_schemas_soft_deleted_rows(catalog: Arc) { + let mut repos = catalog.repositories(); + + let ns1 = populate_namespace( + repos.deref_mut(), + "ns1", + "cpu,tag=1 field=1i\nanother,tag=1 field=1.0", + ) + .await; + let ns2 = populate_namespace( + repos.deref_mut(), + "ns2", + "cpu,tag=1 field=1i\nsomethingelse field=1u", + ) + .await; + + repos + .namespaces() + .soft_delete(&ns2.0.name) + .await + .expect("failed to soft delete namespace"); + + // Otherwise the in-mem catalog deadlocks.... 
(but not postgres) + drop(repos); + + let got = list_schemas(&*catalog) + .await + .expect("should be able to list the schemas") + .collect::>(); + + assert!(got.contains(&ns1), "{:#?}\n\nwant{:#?}", got, &ns1); + assert!(!got.contains(&ns2), "{:#?}\n\n do not want{:#?}", got, &ns2); +} + +/// Ensure that we can create two repo objects and that they instantly share their state. +/// +/// This is a regression test for . +async fn test_two_repos(catalog: Arc) { + let mut repos_1 = catalog.repositories(); + let mut repos_2 = catalog.repositories(); + let repo_1 = repos_1.namespaces(); + let repo_2 = repos_2.namespaces(); + + let namespace_name = NamespaceName::new("test_namespace").unwrap(); + repo_1 + .create(&namespace_name, None, None, None) + .await + .unwrap(); + + repo_2 + .get_by_name(&namespace_name, SoftDeletedRows::AllRows) + .await + .unwrap() + .unwrap(); +} + +async fn test_partition_create_or_get_idempotent(catalog: Arc) { + let mut repos = catalog.repositories(); + + let namespace = arbitrary_namespace(&mut *repos, "ns4").await; + let table_id = arbitrary_table(&mut *repos, "table", &namespace).await.id; + + let key = PartitionKey::from("bananas"); + + let hash_id = PartitionHashId::new(table_id, &key); + + let a = repos + .partitions() + .create_or_get(key.clone(), table_id) + .await + .expect("should create OK"); + + assert_eq!(a.hash_id().unwrap(), &hash_id); + // Test: sort_key_ids from partition_create_or_get_idempotent + assert!(a.sort_key_ids().is_none()); + + // Call create_or_get for the same (key, table_id) pair, to ensure the write is idempotent. + let b = repos + .partitions() + .create_or_get(key.clone(), table_id) + .await + .expect("idempotent write should succeed"); + + assert_eq!(a, b); + + // Check that the hash_id is saved in the database and is returned when queried. 
+ let table_partitions = repos.partitions().list_by_table_id(table_id).await.unwrap(); + assert_eq!(table_partitions.len(), 1); + assert_eq!(table_partitions[0].hash_id().unwrap(), &hash_id); + + // Test: sort_key_ids from partition_create_or_get_idempotent + assert!(table_partitions[0].sort_key_ids().is_none()); +} + +#[track_caller] +fn assert_metric_hit(metrics: &metric::Registry, name: &'static str) { + let histogram = metrics + .get_instrument::>("catalog_op_duration") + .expect("failed to read metric") + .get_observer(&Attributes::from(&[("op", name), ("result", "success")])) + .expect("failed to get observer") + .fetch(); + + let hit_count = histogram.sample_count(); + assert!(hit_count > 0, "metric did not record any calls"); +} + +async fn test_column_create_or_get_many_unchecked(clean_state: R) +where + R: Fn() -> F + Send + Sync, + F: Future> + Send, +{ + // Issue a few calls to create_or_get_many that contain distinct columns and + // covers the full set of column types. + test_column_create_or_get_many_unchecked_sub( + clean_state().await, + &[ + &[ + ("test1", ColumnType::I64), + ("test2", ColumnType::U64), + ("test3", ColumnType::F64), + ("test4", ColumnType::Bool), + ("test5", ColumnType::String), + ("test6", ColumnType::Time), + ("test7", ColumnType::Tag), + ], + &[("test8", ColumnType::String), ("test9", ColumnType::Bool)], + ], + |res| assert_matches!(res, Ok(_)), + ) + .await; + + // Issue two calls with overlapping columns - request should succeed (upsert + // semantics). 
+ test_column_create_or_get_many_unchecked_sub( + clean_state().await, + &[ + &[ + ("test1", ColumnType::I64), + ("test2", ColumnType::U64), + ("test3", ColumnType::F64), + ("test4", ColumnType::Bool), + ], + &[ + ("test1", ColumnType::I64), + ("test2", ColumnType::U64), + ("test3", ColumnType::F64), + ("test4", ColumnType::Bool), + ("test5", ColumnType::String), + ("test6", ColumnType::Time), + ("test7", ColumnType::Tag), + ("test8", ColumnType::String), + ], + ], + |res| assert_matches!(res, Ok(_)), + ) + .await; + + // Issue two calls with the same columns and types. + test_column_create_or_get_many_unchecked_sub( + clean_state().await, + &[ + &[ + ("test1", ColumnType::I64), + ("test2", ColumnType::U64), + ("test3", ColumnType::F64), + ("test4", ColumnType::Bool), + ], + &[ + ("test1", ColumnType::I64), + ("test2", ColumnType::U64), + ("test3", ColumnType::F64), + ("test4", ColumnType::Bool), + ], + ], + |res| assert_matches!(res, Ok(_)), + ) + .await; + + // Issue two calls with overlapping columns with conflicting types and + // observe a correctly populated ColumnTypeMismatch error. + test_column_create_or_get_many_unchecked_sub( + clean_state().await, + &[ + &[ + ("test1", ColumnType::String), + ("test2", ColumnType::String), + ("test3", ColumnType::String), + ("test4", ColumnType::String), + ], + &[ + ("test1", ColumnType::String), + ("test2", ColumnType::Bool), // This one differs + ("test3", ColumnType::String), + // 4 is missing. 
+ ("test5", ColumnType::String), + ("test6", ColumnType::Time), + ("test7", ColumnType::Tag), + ("test8", ColumnType::String), + ], + ], + |res| assert_matches!(res, Err(e) => { + assert_matches!(e, Error::AlreadyExists { descr } => { + assert_eq!(descr, "column test2 is type string but schema update has type bool"); + }) + }), + ).await; +} + +async fn test_column_create_or_get_many_unchecked_sub( + catalog: Arc, + calls: &[&[(&'static str, ColumnType)]], + want: F, +) where + F: FnOnce(Result, Error>) + Send, +{ + let mut repos = catalog.repositories(); + + let namespace = arbitrary_namespace(&mut *repos, "ns4").await; + let table_id = arbitrary_table(&mut *repos, "table", &namespace).await.id; + + let mut last_got = None; + for insert in calls { + let insert = insert + .iter() + .map(|(n, t)| (*n, *t)) + .collect::>(); + + let got = repos + .columns() + .create_or_get_many_unchecked(table_id, insert.clone()) + .await; + + // The returned columns MUST always match the requested + // column values if successful. + if let Ok(got) = &got { + assert_eq!(insert.len(), got.len()); + + for got in got { + assert_eq!(table_id, got.table_id); + let requested_column_type = insert + .get(got.name.as_str()) + .expect("Should have gotten back a column that was inserted"); + assert_eq!(*requested_column_type, got.column_type,); + } + + assert_metric_hit(&catalog.metrics(), "column_create_or_get_many_unchecked"); + } + + last_got = Some(got); + } + + want(last_got.unwrap()); +} + +/// [`Catalog`] wrapper that is helpful for testing. +#[derive(Debug)] +pub(crate) struct TestCatalog { + hold_onto: Mutex>>, + inner: Arc, +} + +impl TestCatalog { + /// Create new test catalog. + pub(crate) fn new(inner: Arc) -> Self { + Self { + hold_onto: Mutex::new(vec![]), + inner, + } + } + + /// Hold onto given value til dropped. 
+ pub(crate) fn hold_onto(&self, o: T) + where + T: Send + 'static, + { + self.hold_onto.lock().push(Box::new(o) as _) + } +} + +impl Display for TestCatalog { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "test({})", self.inner) + } +} + +#[async_trait] +impl Catalog for TestCatalog { + async fn setup(&self) -> Result<(), Error> { + self.inner.setup().await + } + + fn repositories(&self) -> Box { + self.inner.repositories() + } + + fn metrics(&self) -> Arc { + self.inner.metrics() + } + + fn time_provider(&self) -> Arc { + self.inner.time_provider() + } +} + +#[track_caller] +fn assert_gt(a: T, b: T) +where + T: Display + PartialOrd, +{ + assert!(a > b, "failed: {a} > {b}",); +} + +#[track_caller] +fn assert_ge(a: T, b: T) +where + T: Display + PartialOrd, +{ + assert!(a >= b, "failed: {a} >= {b}",); +} diff --git a/iox_catalog/src/kafkaless_transition.rs b/iox_catalog/src/kafkaless_transition.rs new file mode 100644 index 0000000..4216216 --- /dev/null +++ b/iox_catalog/src/kafkaless_transition.rs @@ -0,0 +1,95 @@ +/// Magic number to be used shard indices and shard ids in "kafkaless". +pub(crate) const TRANSITION_SHARD_NUMBER: i32 = 1234; +/// In kafkaless mode all new persisted data uses this shard id. +pub(crate) const TRANSITION_SHARD_ID: ShardId = ShardId::new(TRANSITION_SHARD_NUMBER as i64); +/// In kafkaless mode all new persisted data uses this shard index. +pub(crate) const TRANSITION_SHARD_INDEX: ShardIndex = ShardIndex::new(TRANSITION_SHARD_NUMBER); +pub(crate) const SHARED_TOPIC_NAME: &str = "iox-shared"; +pub(crate) const SHARED_TOPIC_ID: TopicId = TopicId::new(1); +pub(crate) const SHARED_QUERY_POOL_ID: QueryPoolId = QueryPoolId::new(1); +pub(crate) const SHARED_QUERY_POOL: &str = SHARED_TOPIC_NAME; + +/// Unique ID for a `Shard`, assigned by the catalog. Joins to other catalog tables to uniquely +/// identify shards independently of the underlying write buffer implementation. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type)] +#[sqlx(transparent)] +pub(crate) struct ShardId(i64); + +#[allow(missing_docs)] +impl ShardId { + pub(crate) const fn new(v: i64) -> Self { + Self(v) + } +} + +impl std::fmt::Display for ShardId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/// The index of the shard in the set of shards. When Kafka is used as the write buffer, this is +/// the Kafka Partition ID. Used by the router and write buffer to shard requests to a particular +/// index in a set of shards. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type)] +#[sqlx(transparent)] +pub(crate) struct ShardIndex(i32); + +#[allow(missing_docs)] +impl ShardIndex { + pub(crate) const fn new(v: i32) -> Self { + Self(v) + } +} + +impl std::fmt::Display for ShardIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl std::str::FromStr for ShardIndex { + type Err = std::num::ParseIntError; + + fn from_str(s: &str) -> Result { + let v: i32 = s.parse()?; + Ok(Self(v)) + } +} + +/// Data object for a shard. Only one shard record can exist for a given topic and shard +/// index (enforced via uniqueness constraint). 
+#[derive(Debug, Copy, Clone, PartialEq, Eq, sqlx::FromRow)] +pub(crate) struct Shard { + /// the id of the shard, assigned by the catalog + pub(crate) id: ShardId, + /// the topic the shard is reading from + pub(crate) topic_id: TopicId, + /// the shard index of the shard the sequence numbers are coming from, sharded by the router + /// and write buffer + pub(crate) shard_index: ShardIndex, +} + +/// Unique ID for a Topic, assigned by the catalog +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type)] +#[sqlx(transparent)] +pub struct TopicId(i64); + +#[allow(missing_docs)] +impl TopicId { + pub const fn new(v: i64) -> Self { + Self(v) + } +} + +/// Unique ID for a `QueryPool` +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type)] +#[sqlx(transparent)] +pub struct QueryPoolId(i64); + +#[allow(missing_docs)] +impl QueryPoolId { + pub const fn new(v: i64) -> Self { + Self(v) + } +} diff --git a/iox_catalog/src/lib.rs b/iox_catalog/src/lib.rs new file mode 100644 index 0000000..17fa14f --- /dev/null +++ b/iox_catalog/src/lib.rs @@ -0,0 +1,35 @@ +//! The IOx catalog keeps track of the namespaces, tables, columns, parquet files, +//! and deletes in the system. Configuration information for distributing ingest, query +//! and compaction is also stored here. +#![deny(rustdoc::broken_intra_doc_links, rust_2018_idioms)] +#![warn( + missing_copy_implementations, + missing_debug_implementations, + missing_docs, + clippy::explicit_iter_loop, + // See https://github.com/influxdata/influxdb_iox/pull/1671 + clippy::future_not_send, + clippy::use_self, + clippy::clone_on_ref_ptr, + clippy::todo, + clippy::dbg_macro, + unused_crate_dependencies +)] + +// Workaround for "unused crate" lint false positives. 
+use workspace_hack as _; + +pub mod cache; +pub mod constants; +pub mod grpc; +pub mod interface; +pub mod mem; +pub mod metrics; +pub mod migrate; +pub mod postgres; +pub mod sqlite; +pub mod test_helpers; +pub mod util; + +#[cfg(test)] +pub(crate) mod interface_tests; diff --git a/iox_catalog/src/mem.rs b/iox_catalog/src/mem.rs new file mode 100644 index 0000000..0d810fd --- /dev/null +++ b/iox_catalog/src/mem.rs @@ -0,0 +1,1135 @@ +//! This module implements an in-memory implementation of the iox_catalog interface. It can be +//! used for testing or for an IOx designed to run without catalog persistence. + +use crate::{ + constants::{ + MAX_PARQUET_FILES_SELECTED_ONCE_FOR_DELETE, MAX_PARQUET_FILES_SELECTED_ONCE_FOR_RETENTION, + }, + interface::{ + AlreadyExistsSnafu, CasFailure, Catalog, ColumnRepo, Error, NamespaceRepo, ParquetFileRepo, + PartitionRepo, RepoCollection, Result, SoftDeletedRows, TableRepo, + }, + metrics::MetricDecorator, +}; +use async_trait::async_trait; +use data_types::snapshot::partition::PartitionSnapshot; +use data_types::snapshot::table::TableSnapshot; +use data_types::{ + partition_template::{ + NamespacePartitionTemplateOverride, TablePartitionTemplateOverride, TemplatePart, + }, + Column, ColumnId, ColumnType, CompactionLevel, MaxColumnsPerTable, MaxTables, Namespace, + NamespaceId, NamespaceName, NamespaceServiceProtectionLimitsOverride, ObjectStoreId, + ParquetFile, ParquetFileId, ParquetFileParams, Partition, PartitionHashId, PartitionId, + PartitionKey, SkippedCompaction, SortKeyIds, Table, TableId, Timestamp, +}; +use iox_time::TimeProvider; +use parking_lot::Mutex; +use snafu::ensure; +use std::ops::Deref; +use std::{ + collections::{HashMap, HashSet}, + fmt::{Display, Formatter}, + ops::DerefMut, + sync::Arc, +}; + +/// In-memory catalog that implements the `RepoCollection` and individual repo traits from +/// the catalog interface. 
+pub struct MemCatalog { + metrics: Arc, + collections: Arc>, + time_provider: Arc, +} + +impl MemCatalog { + /// return new initialized [`MemCatalog`] + pub fn new(metrics: Arc, time_provider: Arc) -> Self { + Self { + metrics, + collections: Default::default(), + time_provider, + } + } + + /// Add partition directly, for testing purposes only as it does not do any consistency or + /// uniqueness checks + pub fn add_partition(&self, partition: Partition) { + let mut stage = self.collections.lock(); + stage.partitions.push(partition.into()); + } +} + +impl std::fmt::Debug for MemCatalog { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MemCatalog").finish_non_exhaustive() + } +} + +/// A wrapper around `T` adding a generation number +#[derive(Debug, Clone)] +struct Versioned { + generation: u64, + value: T, +} + +impl Deref for Versioned { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.value + } +} + +impl DerefMut for Versioned { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.value + } +} + +impl From for Versioned { + fn from(value: T) -> Self { + Self { + generation: 0, + value, + } + } +} + +#[derive(Default, Debug, Clone)] +struct MemCollections { + namespaces: Vec, + tables: Vec>, + columns: Vec, + partitions: Vec>, + skipped_compactions: Vec, + parquet_files: Vec, +} + +/// transaction bound to an in-memory catalog. 
+#[derive(Debug)] +pub struct MemTxn { + collections: Arc>, + time_provider: Arc, +} + +impl Display for MemCatalog { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Memory") + } +} + +#[async_trait] +impl Catalog for MemCatalog { + async fn setup(&self) -> Result<(), Error> { + Ok(()) + } + + fn repositories(&self) -> Box { + let collections = Arc::clone(&self.collections); + Box::new(MetricDecorator::new( + MemTxn { + collections, + time_provider: self.time_provider(), + }, + Arc::clone(&self.metrics), + self.time_provider(), + )) + } + + #[cfg(test)] + fn metrics(&self) -> Arc { + Arc::clone(&self.metrics) + } + + fn time_provider(&self) -> Arc { + Arc::clone(&self.time_provider) + } +} + +impl RepoCollection for MemTxn { + fn namespaces(&mut self) -> &mut dyn NamespaceRepo { + self + } + + fn tables(&mut self) -> &mut dyn TableRepo { + self + } + + fn columns(&mut self) -> &mut dyn ColumnRepo { + self + } + + fn partitions(&mut self) -> &mut dyn PartitionRepo { + self + } + + fn parquet_files(&mut self) -> &mut dyn ParquetFileRepo { + self + } +} + +#[async_trait] +impl NamespaceRepo for MemTxn { + async fn create( + &mut self, + name: &NamespaceName<'_>, + partition_template: Option, + retention_period_ns: Option, + service_protection_limits: Option, + ) -> Result { + let mut stage = self.collections.lock(); + + if stage.namespaces.iter().any(|n| n.name == name.as_str()) { + return Err(Error::AlreadyExists { + descr: name.to_string(), + }); + } + + let max_tables = service_protection_limits + .and_then(|l| l.max_tables) + .unwrap_or_default(); + let max_columns_per_table = service_protection_limits + .and_then(|l| l.max_columns_per_table) + .unwrap_or_default(); + + let namespace = Namespace { + id: NamespaceId::new(stage.namespaces.len() as i64 + 1), + name: name.to_string(), + max_tables, + max_columns_per_table, + retention_period_ns, + deleted_at: None, + partition_template: partition_template.unwrap_or_default(), + }; + 
stage.namespaces.push(namespace); + Ok(stage.namespaces.last().unwrap().clone()) + } + + async fn list(&mut self, deleted: SoftDeletedRows) -> Result> { + let stage = self.collections.lock(); + + Ok(filter_namespace_soft_delete(&stage.namespaces, deleted) + .cloned() + .collect()) + } + + async fn get_by_id( + &mut self, + id: NamespaceId, + deleted: SoftDeletedRows, + ) -> Result> { + let stage = self.collections.lock(); + + let res = filter_namespace_soft_delete(&stage.namespaces, deleted) + .find(|n| n.id == id) + .cloned(); + + Ok(res) + } + + async fn get_by_name( + &mut self, + name: &str, + deleted: SoftDeletedRows, + ) -> Result> { + let stage = self.collections.lock(); + + let res = filter_namespace_soft_delete(&stage.namespaces, deleted) + .find(|n| n.name == name) + .cloned(); + + Ok(res) + } + + // performs a cascading delete of all things attached to the namespace, then deletes the + // namespace + async fn soft_delete(&mut self, name: &str) -> Result<()> { + let mut stage = self.collections.lock(); + let timestamp = self.time_provider.now(); + // get namespace by name + match stage.namespaces.iter_mut().find(|n| n.name == name) { + Some(n) => { + n.deleted_at = Some(Timestamp::from(timestamp)); + Ok(()) + } + None => Err(Error::NotFound { + descr: name.to_string(), + }), + } + } + + async fn update_table_limit(&mut self, name: &str, new_max: MaxTables) -> Result { + let mut stage = self.collections.lock(); + match stage.namespaces.iter_mut().find(|n| n.name == name) { + Some(n) => { + n.max_tables = new_max; + Ok(n.clone()) + } + None => Err(Error::NotFound { + descr: name.to_string(), + }), + } + } + + async fn update_column_limit( + &mut self, + name: &str, + new_max: MaxColumnsPerTable, + ) -> Result { + let mut stage = self.collections.lock(); + match stage.namespaces.iter_mut().find(|n| n.name == name) { + Some(n) => { + n.max_columns_per_table = new_max; + Ok(n.clone()) + } + None => Err(Error::NotFound { + descr: name.to_string(), + }), + } + } 
+ + async fn update_retention_period( + &mut self, + name: &str, + retention_period_ns: Option, + ) -> Result { + let mut stage = self.collections.lock(); + match stage.namespaces.iter_mut().find(|n| n.name == name) { + Some(n) => { + n.retention_period_ns = retention_period_ns; + Ok(n.clone()) + } + None => Err(Error::NotFound { + descr: name.to_string(), + }), + } + } +} + +#[async_trait] +impl TableRepo for MemTxn { + async fn create( + &mut self, + name: &str, + partition_template: TablePartitionTemplateOverride, + namespace_id: NamespaceId, + ) -> Result
{ + let mut stage = self.collections.lock(); + + let table = { + // this block is just to ensure the mem impl correctly creates TableCreateLimitError in + // tests, we don't care about any of the errors it is discarding + stage + .namespaces + .iter() + .find(|n| n.id == namespace_id) + .cloned() + .ok_or_else(|| Error::NotFound { + // we're never going to use this error, this is just for flow control, + // so it doesn't matter that we only have the ID, not the name + descr: "".to_string(), + }) + .and_then(|n| { + let max_tables = n.max_tables; + let tables_count = stage + .tables + .iter() + .filter(|t| t.namespace_id == namespace_id) + .count(); + if tables_count >= max_tables.get() { + return Err(Error::LimitExceeded { + descr: format!( + "couldn't create table {}; limit reached on namespace {}", + name, namespace_id + ), + }); + } + Ok(()) + })?; + + match stage + .tables + .iter() + .find(|t| t.name == name && t.namespace_id == namespace_id) + { + Some(_t) => { + return Err(Error::AlreadyExists { + descr: format!("table '{name}' in namespace {namespace_id}"), + }) + } + None => { + let table = Table { + id: TableId::new(stage.tables.len() as i64 + 1), + namespace_id, + name: name.to_string(), + partition_template, + }; + stage.tables.push(table.into()); + stage.tables.last().unwrap().value.clone() + } + } + }; + + // Partitioning is only supported for tags, so create tag columns for all `TagValue` + // partition template parts. It's important this happens within the table creation + // transaction so that there isn't a possibility of a concurrent write creating these + // columns with an unsupported type. 
+ for template_part in table.partition_template.parts() { + if let TemplatePart::TagValue(tag_name) = template_part { + create_or_get_column(&mut stage, tag_name, table.id, ColumnType::Tag)?; + } + } + + Ok(table) + } + + async fn get_by_id(&mut self, table_id: TableId) -> Result> { + let stage = self.collections.lock(); + + let mut tables = stage.tables.iter(); + Ok(tables.find(|t| t.id == table_id).map(|v| v.value.clone())) + } + + async fn get_by_namespace_and_name( + &mut self, + namespace_id: NamespaceId, + name: &str, + ) -> Result> { + let stage = self.collections.lock(); + + let mut tables = stage.tables.iter(); + let search = tables.find(|t| t.namespace_id == namespace_id && t.name == name); + Ok(search.map(|v| v.value.clone())) + } + + async fn list_by_namespace_id(&mut self, namespace_id: NamespaceId) -> Result> { + let stage = self.collections.lock(); + + let tables = stage.tables.iter(); + let filtered = tables.filter(|t| t.namespace_id == namespace_id); + let tables: Vec<_> = filtered.map(|v| v.value.clone()).collect(); + Ok(tables) + } + + async fn list(&mut self) -> Result> { + let stage = self.collections.lock(); + Ok(stage.tables.iter().map(|v| v.value.clone()).collect()) + } + + async fn snapshot(&mut self, table_id: TableId) -> Result { + let mut guard = self.collections.lock(); + + let (table, generation) = { + let mut tables = guard.tables.iter_mut(); + let search = tables.find(|x| x.id == table_id); + let table = search.ok_or_else(|| Error::NotFound { + descr: table_id.to_string(), + })?; + + let generation = table.generation; + table.generation += 1; + (table.value.clone(), generation) + }; + + let columns = guard + .columns + .iter() + .filter(|x| x.table_id == table_id) + .cloned() + .collect(); + + let partitions = guard + .partitions + .iter() + .filter(|x| x.table_id == table_id) + .map(|v| v.value.clone()) + .collect(); + + Ok(TableSnapshot::encode( + table, partitions, columns, generation, + )?) 
+ } +} + +#[async_trait] +impl ColumnRepo for MemTxn { + async fn create_or_get( + &mut self, + name: &str, + table_id: TableId, + column_type: ColumnType, + ) -> Result { + let mut stage = self.collections.lock(); + create_or_get_column(&mut stage, name, table_id, column_type) + } + + async fn create_or_get_many_unchecked( + &mut self, + table_id: TableId, + columns: HashMap<&str, ColumnType>, + ) -> Result> { + // Explicitly NOT using `create_or_get` in this function: the Postgres catalog doesn't + // check column limits when inserting many columns because it's complicated and expensive, + // and for testing purposes the in-memory catalog needs to match its functionality. + + let mut stage = self.collections.lock(); + + let out: Vec<_> = columns + .iter() + .map(|(&column_name, &column_type)| { + match stage + .columns + .iter() + .find(|t| t.name == column_name && t.table_id == table_id) + { + Some(c) => { + ensure!( + column_type == c.column_type, + AlreadyExistsSnafu { + descr: format!( + "column {} is type {} but schema update has type {}", + column_name, c.column_type, column_type + ), + } + ); + Ok(c.clone()) + } + None => { + let new_column = Column { + id: ColumnId::new(stage.columns.len() as i64 + 1), + table_id, + name: column_name.to_string(), + column_type, + }; + stage.columns.push(new_column); + Ok(stage.columns.last().unwrap().clone()) + } + } + }) + .collect::>>()?; + + Ok(out) + } + + async fn list_by_namespace_id(&mut self, namespace_id: NamespaceId) -> Result> { + let stage = self.collections.lock(); + + let table_ids: Vec<_> = stage + .tables + .iter() + .filter(|t| t.namespace_id == namespace_id) + .map(|t| t.id) + .collect(); + let columns: Vec<_> = stage + .columns + .iter() + .filter(|c| table_ids.contains(&c.table_id)) + .cloned() + .collect(); + + Ok(columns) + } + + async fn list_by_table_id(&mut self, table_id: TableId) -> Result> { + let stage = self.collections.lock(); + + let columns: Vec<_> = stage + .columns + .iter() + 
.filter(|c| c.table_id == table_id) + .cloned() + .collect(); + + Ok(columns) + } + + async fn list(&mut self) -> Result> { + let stage = self.collections.lock(); + Ok(stage.columns.clone()) + } +} + +#[async_trait] +impl PartitionRepo for MemTxn { + async fn create_or_get(&mut self, key: PartitionKey, table_id: TableId) -> Result { + let mut stage = self.collections.lock(); + + let partition = match stage + .partitions + .iter() + .find(|p| p.partition_key == key && p.table_id == table_id) + { + Some(p) => p, + None => { + let hash_id = PartitionHashId::new(table_id, &key); + let p = Partition::new_catalog_only( + PartitionId::new(stage.partitions.len() as i64 + 1), + Some(hash_id), + table_id, + key, + SortKeyIds::default(), + None, + ); + stage.partitions.push(p.into()); + stage.partitions.last().unwrap() + } + }; + + Ok(partition.value.clone()) + } + + async fn get_by_id_batch(&mut self, partition_ids: &[PartitionId]) -> Result> { + let lookup = partition_ids.iter().collect::>(); + + let stage = self.collections.lock(); + + Ok(stage + .partitions + .iter() + .filter(|p| lookup.contains(&p.id)) + .map(|x| x.value.clone()) + .collect()) + } + + async fn list_by_table_id(&mut self, table_id: TableId) -> Result> { + let stage = self.collections.lock(); + + let partitions: Vec<_> = stage + .partitions + .iter() + .filter(|p| p.table_id == table_id) + .map(|x| x.value.clone()) + .collect(); + Ok(partitions) + } + + async fn list_ids(&mut self) -> Result> { + let stage = self.collections.lock(); + + let partitions: Vec<_> = stage.partitions.iter().map(|p| p.id).collect(); + + Ok(partitions) + } + + async fn cas_sort_key( + &mut self, + partition_id: PartitionId, + old_sort_key_ids: Option<&SortKeyIds>, + new_sort_key_ids: &SortKeyIds, + ) -> Result> { + let mut stage = self.collections.lock(); + + match stage.partitions.iter_mut().find(|p| p.id == partition_id) { + Some(p) if p.sort_key_ids() == old_sort_key_ids => { + p.set_sort_key_ids(new_sort_key_ids); + 
Ok(p.value.clone()) + } + Some(p) => { + return Err(CasFailure::ValueMismatch( + p.sort_key_ids().cloned().unwrap_or_default(), + )); + } + None => Err(CasFailure::QueryError(Error::NotFound { + descr: partition_id.to_string(), + })), + } + } + + async fn record_skipped_compaction( + &mut self, + partition_id: PartitionId, + reason: &str, + num_files: usize, + limit_num_files: usize, + limit_num_files_first_in_partition: usize, + estimated_bytes: u64, + limit_bytes: u64, + ) -> Result<()> { + let mut stage = self.collections.lock(); + + let reason = reason.to_string(); + let skipped_at = Timestamp::from(self.time_provider.now()); + + let sc = SkippedCompaction { + partition_id, + reason, + skipped_at, + num_files: num_files as i64, + limit_num_files: limit_num_files as i64, + limit_num_files_first_in_partition: limit_num_files_first_in_partition as i64, + estimated_bytes: estimated_bytes as i64, + limit_bytes: limit_bytes as i64, + }; + + match stage + .skipped_compactions + .iter_mut() + .find(|s| s.partition_id == partition_id) + { + Some(s) => { + *s = sc; + } + None => stage.skipped_compactions.push(sc), + } + Ok(()) + } + + async fn get_in_skipped_compactions( + &mut self, + partition_ids: &[PartitionId], + ) -> Result> { + let stage = self.collections.lock(); + let find: HashSet<&PartitionId> = partition_ids.iter().collect(); + Ok(stage + .skipped_compactions + .iter() + .filter(|s| find.contains(&s.partition_id)) + .cloned() + .collect()) + } + + async fn list_skipped_compactions(&mut self) -> Result> { + let stage = self.collections.lock(); + Ok(stage.skipped_compactions.clone()) + } + + async fn delete_skipped_compactions( + &mut self, + partition_id: PartitionId, + ) -> Result> { + use std::mem; + + let mut stage = self.collections.lock(); + let skipped_compactions = mem::take(&mut stage.skipped_compactions); + let (mut removed, remaining) = skipped_compactions + .into_iter() + .partition(|sc| sc.partition_id == partition_id); + stage.skipped_compactions 
= remaining; + + match removed.pop() { + Some(sc) if removed.is_empty() => Ok(Some(sc)), + Some(_) => unreachable!("There must be exactly one skipped compaction per partition"), + None => Ok(None), + } + } + + async fn most_recent_n(&mut self, n: usize) -> Result> { + let stage = self.collections.lock(); + let iter = stage.partitions.iter().rev().take(n); + Ok(iter.map(|x| x.value.clone()).collect()) + } + + async fn partitions_new_file_between( + &mut self, + minimum_time: Timestamp, + maximum_time: Option, + ) -> Result> { + let stage = self.collections.lock(); + + let partitions: Vec<_> = stage + .partitions + .iter() + .filter(|p| { + p.new_file_at > Some(minimum_time) + && maximum_time + .map(|max| p.new_file_at < Some(max)) + .unwrap_or(true) + }) + .map(|p| p.id) + .collect(); + + Ok(partitions) + } + + async fn list_old_style(&mut self) -> Result> { + let stage = self.collections.lock(); + + let old_style: Vec<_> = stage + .partitions + .iter() + .filter(|p| p.hash_id().is_none()) + .map(|x| x.value.clone()) + .collect(); + + Ok(old_style) + } + + async fn snapshot(&mut self, partition_id: PartitionId) -> Result { + let mut guard = self.collections.lock(); + let (partition, generation) = { + let search = guard.partitions.iter_mut().find(|x| x.id == partition_id); + let partition = search.ok_or_else(|| Error::NotFound { + descr: format!("Partition {partition_id} not found"), + })?; + + let generation = partition.generation; + partition.generation += 1; + (partition.value.clone(), generation) + }; + + let files = guard + .parquet_files + .iter() + .filter(|x| x.partition_id == partition_id && x.to_delete.is_none()) + .cloned() + .collect(); + + let search = guard.tables.iter().find(|x| x.id == partition.table_id); + let table = search.ok_or_else(|| Error::NotFound { + descr: format!("Table {} not found", partition.table_id), + })?; + + let sc = guard + .skipped_compactions + .iter() + .find(|sc| sc.partition_id == partition_id) + .cloned(); + + 
Ok(PartitionSnapshot::encode( + table.namespace_id, + partition, + files, + sc, + generation, + )?) + } +} + +#[async_trait] +impl ParquetFileRepo for MemTxn { + async fn flag_for_delete_by_retention(&mut self) -> Result> { + let mut stage = self.collections.lock(); + let now = Timestamp::from(self.time_provider.now()); + let stage = stage.deref_mut(); + + Ok(stage + .parquet_files + .iter_mut() + // don't flag if already flagged for deletion + .filter(|f| f.to_delete.is_none()) + .filter_map(|f| { + // table retention, if it exists, overrides namespace retention + // TODO - include check of table retention period once implemented + stage + .namespaces + .iter() + .find(|n| n.id == f.namespace_id) + .and_then(|ns| { + ns.retention_period_ns.and_then(|rp| { + if f.max_time < now - rp { + f.to_delete = Some(now); + Some((f.partition_id, f.object_store_id)) + } else { + None + } + }) + }) + }) + .take(MAX_PARQUET_FILES_SELECTED_ONCE_FOR_RETENTION as usize) + .collect()) + } + + async fn delete_old_ids_only(&mut self, older_than: Timestamp) -> Result> { + let mut stage = self.collections.lock(); + + let (delete, keep): (Vec<_>, Vec<_>) = stage.parquet_files.iter().cloned().partition( + |f| matches!(f.to_delete, Some(marked_deleted) if marked_deleted < older_than), + ); + + stage.parquet_files = keep; + + let delete = delete + .into_iter() + .take(MAX_PARQUET_FILES_SELECTED_ONCE_FOR_DELETE as usize) + .map(|f| f.object_store_id) + .collect(); + Ok(delete) + } + + async fn list_by_partition_not_to_delete_batch( + &mut self, + partition_ids: Vec, + ) -> Result> { + let partition_ids = partition_ids.into_iter().collect::>(); + let stage = self.collections.lock(); + + Ok(stage + .parquet_files + .iter() + .filter(|f| partition_ids.contains(&f.partition_id) && f.to_delete.is_none()) + .cloned() + .collect()) + } + + async fn get_by_object_store_id( + &mut self, + object_store_id: ObjectStoreId, + ) -> Result> { + let stage = self.collections.lock(); + + Ok(stage + 
.parquet_files + .iter() + .find(|f| f.object_store_id.eq(&object_store_id)) + .cloned()) + } + + async fn exists_by_object_store_id_batch( + &mut self, + object_store_ids: Vec, + ) -> Result> { + let stage = self.collections.lock(); + + Ok(stage + .parquet_files + .iter() + .filter(|f| object_store_ids.contains(&f.object_store_id)) + .map(|f| f.object_store_id) + .collect()) + } + + async fn create_upgrade_delete( + &mut self, + partition_id: PartitionId, + delete: &[ObjectStoreId], + upgrade: &[ObjectStoreId], + create: &[ParquetFileParams], + target_level: CompactionLevel, + ) -> Result> { + let delete_set = delete.iter().copied().collect::>(); + let upgrade_set = upgrade.iter().copied().collect::>(); + + assert!( + delete_set.is_disjoint(&upgrade_set), + "attempted to upgrade a file scheduled for delete" + ); + + let mut collections = self.collections.lock(); + let mut stage = collections.clone(); + + for id in delete { + let marked_at = Timestamp::from(self.time_provider.now()); + flag_for_delete(&mut stage, partition_id, *id, marked_at)?; + } + + update_compaction_level(&mut stage, partition_id, upgrade, target_level)?; + + let mut ids = Vec::with_capacity(create.len()); + for file in create { + if file.partition_id != partition_id { + return Err(Error::External { + source: format!("Inconsistent ParquetFileParams, expected PartitionId({partition_id}) got PartitionId({})", file.partition_id).into(), + }); + } + let res = create_parquet_file(&mut stage, file.clone())?; + ids.push(res.id); + } + + *collections = stage; + + Ok(ids) + } +} + +fn filter_namespace_soft_delete<'a>( + v: impl IntoIterator, + deleted: SoftDeletedRows, +) -> impl Iterator { + v.into_iter().filter(move |v| match deleted { + SoftDeletedRows::AllRows => true, + SoftDeletedRows::ExcludeDeleted => v.deleted_at.is_none(), + SoftDeletedRows::OnlyDeleted => v.deleted_at.is_some(), + }) +} + +fn create_or_get_column( + stage: &mut MemCollections, + name: &str, + table_id: TableId, + column_type: 
ColumnType, +) -> Result { + // this block is just to ensure the mem impl correctly creates ColumnCreateLimitError in + // tests, we don't care about any of the errors it is discarding + stage + .tables + .iter() + .find(|t| t.id == table_id) + .cloned() + .ok_or(Error::NotFound { + descr: format!("table: {}", table_id), + }) // error never used, this is just for flow control + .and_then(|t| { + stage + .namespaces + .iter() + .find(|n| n.id == t.namespace_id) + .cloned() + .ok_or_else(|| Error::NotFound { + // we're never going to use this error, this is just for flow control, + // so it doesn't matter that we only have the ID, not the name + descr: "".to_string(), + }) + .and_then(|n| { + let max_columns_per_table = n.max_columns_per_table; + let columns_count = stage + .columns + .iter() + .filter(|t| t.table_id == table_id) + .count(); + if columns_count >= max_columns_per_table.get() { + return Err(Error::LimitExceeded { + descr: format!( + "couldn't create column {} in table {}; limit reached on namespace", + name, table_id + ), + }); + } + Ok(()) + })?; + Ok(()) + })?; + + let column = match stage + .columns + .iter() + .find(|t| t.name == name && t.table_id == table_id) + { + Some(c) => { + ensure!( + column_type == c.column_type, + AlreadyExistsSnafu { + descr: format!( + "column {} is type {} but schema update has type {}", + name, c.column_type, column_type + ), + } + ); + c + } + None => { + let column = Column { + id: ColumnId::new(stage.columns.len() as i64 + 1), + table_id, + name: name.to_string(), + column_type, + }; + stage.columns.push(column); + stage.columns.last().unwrap() + } + }; + + Ok(column.clone()) +} + +// The following three functions are helpers to the create_upgrade_delete method. +// They are also used by the respective create/flag_for_delete/update_compaction_level methods. 
+fn create_parquet_file( + stage: &mut MemCollections, + parquet_file_params: ParquetFileParams, +) -> Result { + if stage + .parquet_files + .iter() + .any(|f| f.object_store_id == parquet_file_params.object_store_id) + { + return Err(Error::AlreadyExists { + descr: parquet_file_params.object_store_id.to_string(), + }); + } + + let parquet_file = ParquetFile::from_params( + parquet_file_params, + ParquetFileId::new(stage.parquet_files.len() as i64 + 1), + ); + let created_at = parquet_file.created_at; + let partition_id = parquet_file.partition_id; + stage.parquet_files.push(parquet_file); + + // Update the new_file_at field its partition to the time of created_at + let partition = stage + .partitions + .iter_mut() + .find(|p| p.id == partition_id) + .ok_or(Error::NotFound { + descr: partition_id.to_string(), + })?; + partition.new_file_at = Some(created_at); + + Ok(stage.parquet_files.last().unwrap().clone()) +} + +fn flag_for_delete( + stage: &mut MemCollections, + partition_id: PartitionId, + id: ObjectStoreId, + marked_at: Timestamp, +) -> Result<()> { + match stage + .parquet_files + .iter_mut() + .find(|p| p.object_store_id == id && p.partition_id == partition_id) + { + Some(f) if f.to_delete.is_none() => f.to_delete = Some(marked_at), + _ => { + return Err(Error::NotFound { + descr: format!("parquet file {id} not found for delete"), + }) + } + } + + Ok(()) +} + +fn update_compaction_level( + stage: &mut MemCollections, + partition_id: PartitionId, + object_store_ids: &[ObjectStoreId], + compaction_level: CompactionLevel, +) -> Result> { + let all_ids = stage + .parquet_files + .iter() + .filter(|f| f.partition_id == partition_id && f.to_delete.is_none()) + .map(|f| f.object_store_id) + .collect::>(); + for id in object_store_ids { + if !all_ids.contains(id) { + return Err(Error::NotFound { + descr: format!("parquet file {id} not found for upgrade"), + }); + } + } + + let update_ids = object_store_ids.iter().copied().collect::>(); + let mut updated = 
Vec::with_capacity(object_store_ids.len()); + for f in stage + .parquet_files + .iter_mut() + .filter(|p| update_ids.contains(&p.object_store_id) && p.partition_id == partition_id) + { + f.compaction_level = compaction_level; + updated.push(f.object_store_id); + } + + Ok(updated) +} + +#[cfg(test)] +mod tests { + use iox_time::SystemProvider; + + use super::*; + use std::sync::Arc; + + #[tokio::test] + async fn test_catalog() { + crate::interface_tests::test_catalog(|| async { + let metrics = Arc::new(metric::Registry::default()); + let time_provider = Arc::new(SystemProvider::new()); + let x: Arc = Arc::new(MemCatalog::new(metrics, time_provider)); + x + }) + .await; + } +} diff --git a/iox_catalog/src/metrics.rs b/iox_catalog/src/metrics.rs new file mode 100644 index 0000000..b179fd3 --- /dev/null +++ b/iox_catalog/src/metrics.rs @@ -0,0 +1,203 @@ +//! Metric instrumentation for catalog implementations. + +use crate::interface::{ + CasFailure, ColumnRepo, NamespaceRepo, ParquetFileRepo, PartitionRepo, RepoCollection, Result, + SoftDeletedRows, TableRepo, +}; +use async_trait::async_trait; +use data_types::snapshot::table::TableSnapshot; +use data_types::{ + partition_template::{NamespacePartitionTemplateOverride, TablePartitionTemplateOverride}, + snapshot::partition::PartitionSnapshot, + Column, ColumnType, CompactionLevel, MaxColumnsPerTable, MaxTables, Namespace, NamespaceId, + NamespaceName, NamespaceServiceProtectionLimitsOverride, ObjectStoreId, ParquetFile, + ParquetFileId, ParquetFileParams, Partition, PartitionId, PartitionKey, SkippedCompaction, + SortKeyIds, Table, TableId, Timestamp, +}; +use iox_time::TimeProvider; +use metric::{DurationHistogram, Metric}; +use std::{collections::HashMap, fmt::Debug, sync::Arc}; + +/// Decorates a implementation of the catalog's [`RepoCollection`] (and the +/// transactional variant) with instrumentation that emits latency histograms +/// for each method. 
+/// +/// Values are recorded under the `catalog_op_duration` metric, labelled by +/// operation name and result (success/error). +#[derive(Debug)] +pub struct MetricDecorator { + inner: T, + time_provider: Arc, + metrics: Arc, +} + +impl MetricDecorator { + /// Wrap `T` with instrumentation recording operation latency in `metrics`. + pub fn new( + inner: T, + metrics: Arc, + time_provider: Arc, + ) -> Self { + Self { + inner, + time_provider, + metrics, + } + } +} + +impl RepoCollection for MetricDecorator +where + T: NamespaceRepo + TableRepo + ColumnRepo + PartitionRepo + ParquetFileRepo + Debug, +{ + fn namespaces(&mut self) -> &mut dyn NamespaceRepo { + self + } + + fn tables(&mut self) -> &mut dyn TableRepo { + self + } + + fn columns(&mut self) -> &mut dyn ColumnRepo { + self + } + + fn partitions(&mut self) -> &mut dyn PartitionRepo { + self + } + + fn parquet_files(&mut self) -> &mut dyn ParquetFileRepo { + self + } +} + +/// Emit a trait impl for `impl_trait` that delegates calls to the inner +/// implementation, recording the duration and result to the metrics registry. +/// +/// Format: +/// +/// ```ignore +/// decorate!( +/// impl_trait = , +/// methods = [ +/// "" = ; +/// "" = ; +/// // ... and so on +/// ] +/// ); +/// ``` +/// +/// All methods of a given trait MUST be defined in the `decorate!()` call so +/// they are all instrumented or the decorator will not compile as it won't +/// fully implement the trait. +macro_rules! decorate { + ( + impl_trait = $trait:ident, + methods = [$( + $metric:literal = $method:ident( + &mut self $(,)? + $($arg:ident : $t:ty),* + ) -> Result<$out:ty$(, $err:ty)?>; + )+] + ) => { + #[async_trait] + impl $trait for MetricDecorator { + /// NOTE: if you're seeing an error here about "not all trait items + /// implemented" or something similar, one or more methods are + /// missing from / incorrectly defined in the decorate!() blocks + /// below. 
+ + $( + async fn $method(&mut self, $($arg : $t),*) -> Result<$out$(, $err)?> { + let observer: Metric = self.metrics.register_metric( + "catalog_op_duration", + "catalog call duration", + ); + + let t = self.time_provider.now(); + let res = self.inner.$method($($arg),*).await; + + // Avoid exploding if time goes backwards - simply drop the + // measurement if it happens. + if let Some(delta) = self.time_provider.now().checked_duration_since(t) { + let tag = match &res { + Ok(_) => "success", + Err(_) => "error", + }; + observer.recorder(&[("op", $metric), ("result", tag)]).record(delta); + } + + res + } + )+ + } + }; +} + +decorate!( + impl_trait = NamespaceRepo, + methods = [ + "namespace_create" = create(&mut self, name: &NamespaceName<'_>, partition_template: Option, retention_period_ns: Option, service_protection_limits: Option) -> Result; + "namespace_update_retention_period" = update_retention_period(&mut self, name: &str, retention_period_ns: Option) -> Result; + "namespace_list" = list(&mut self, deleted: SoftDeletedRows) -> Result>; + "namespace_get_by_id" = get_by_id(&mut self, id: NamespaceId, deleted: SoftDeletedRows) -> Result>; + "namespace_get_by_name" = get_by_name(&mut self, name: &str, deleted: SoftDeletedRows) -> Result>; + "namespace_soft_delete" = soft_delete(&mut self, name: &str) -> Result<()>; + "namespace_update_table_limit" = update_table_limit(&mut self, name: &str, new_max: MaxTables) -> Result; + "namespace_update_column_limit" = update_column_limit(&mut self, name: &str, new_max: MaxColumnsPerTable) -> Result; + ] +); + +decorate!( + impl_trait = TableRepo, + methods = [ + "table_create" = create(&mut self, name: &str, partition_template: TablePartitionTemplateOverride, namespace_id: NamespaceId) -> Result
; + "table_get_by_id" = get_by_id(&mut self, table_id: TableId) -> Result>; + "table_get_by_namespace_and_name" = get_by_namespace_and_name(&mut self, namespace_id: NamespaceId, name: &str) -> Result>; + "table_list_by_namespace_id" = list_by_namespace_id(&mut self, namespace_id: NamespaceId) -> Result>; + "table_list" = list(&mut self) -> Result>; + "table_snapshot" = snapshot(&mut self, table_id: TableId) -> Result; + ] +); + +decorate!( + impl_trait = ColumnRepo, + methods = [ + "column_create_or_get" = create_or_get(&mut self, name: &str, table_id: TableId, column_type: ColumnType) -> Result; + "column_list_by_namespace_id" = list_by_namespace_id(&mut self, namespace_id: NamespaceId) -> Result>; + "column_list_by_table_id" = list_by_table_id(&mut self, table_id: TableId) -> Result>; + "column_create_or_get_many_unchecked" = create_or_get_many_unchecked(&mut self, table_id: TableId, columns: HashMap<&str, ColumnType>) -> Result>; + "column_list" = list(&mut self) -> Result>; + ] +); + +decorate!( + impl_trait = PartitionRepo, + methods = [ + "partition_create_or_get" = create_or_get(&mut self, key: PartitionKey, table_id: TableId) -> Result; + "partition_get_by_id_batch" = get_by_id_batch(&mut self, partition_ids: &[PartitionId]) -> Result>; + "partition_list_by_table_id" = list_by_table_id(&mut self, table_id: TableId) -> Result>; + "partition_list_ids" = list_ids(&mut self) -> Result>; + "partition_update_sort_key" = cas_sort_key(&mut self, partition_id: PartitionId, old_sort_key_ids: Option<&SortKeyIds>, new_sort_key_ids: &SortKeyIds) -> Result>; + "partition_record_skipped_compaction" = record_skipped_compaction(&mut self, partition_id: PartitionId, reason: &str, num_files: usize, limit_num_files: usize, limit_num_files_first_in_partition: usize, estimated_bytes: u64, limit_bytes: u64) -> Result<()>; + "partition_list_skipped_compactions" = list_skipped_compactions(&mut self) -> Result>; + "partition_delete_skipped_compactions" = 
delete_skipped_compactions(&mut self, partition_id: PartitionId) -> Result>; + "partition_most_recent_n" = most_recent_n(&mut self, n: usize) -> Result>; + "partition_partitions_new_file_between" = partitions_new_file_between(&mut self, minimum_time: Timestamp, maximum_time: Option) -> Result>; + "partition_get_in_skipped_compactions" = get_in_skipped_compactions(&mut self, partition_ids: &[PartitionId]) -> Result>; + "partition_list_old_style" = list_old_style(&mut self) -> Result>; + "partition_snapshot" = snapshot(&mut self, partition_id: PartitionId) -> Result; + ] +); + +decorate!( + impl_trait = ParquetFileRepo, + methods = [ + "parquet_flag_for_delete_by_retention" = flag_for_delete_by_retention(&mut self) -> Result>; + "parquet_delete_old_ids_only" = delete_old_ids_only(&mut self, older_than: Timestamp) -> Result>; + "parquet_list_by_partition_not_to_delete_batch" = list_by_partition_not_to_delete_batch(&mut self, partition_ids: Vec) -> Result>; + "parquet_get_by_object_store_id" = get_by_object_store_id(&mut self, object_store_id: ObjectStoreId) -> Result>; + "parquet_exists_by_object_store_id_batch" = exists_by_object_store_id_batch(&mut self, object_store_ids: Vec) -> Result>; + "parquet_create_upgrade_delete" = create_upgrade_delete(&mut self, partition_id: PartitionId, delete: &[ObjectStoreId], upgrade: &[ObjectStoreId], create: &[ParquetFileParams], target_level: CompactionLevel) -> Result>; + ] +); diff --git a/iox_catalog/src/migrate.rs b/iox_catalog/src/migrate.rs new file mode 100644 index 0000000..5bbf963 --- /dev/null +++ b/iox_catalog/src/migrate.rs @@ -0,0 +1,2437 @@ +//! Better migrations. +//! +//! # Why +//! +//! SQLx migrations don't work for us, see: +//! +//! - +//! - +//! +//! # Usage +//! +//! Just place your migration in the `migrations` folder. They basically work like normal SQLx migrations but there are +//! a few extra, magic comments you can put in your code to modify the behavior. +//! +//! ## Steps +//! +//! 
The entire SQL text will be executed as a single statement. However, you can split it into multiple steps by using +//! a marker: +//! +//! ```sql +//! CREATE TABLE t1 (x INT); +//! +//! -- IOX_STEP_BOUNDARY +//! +//! CREATE TABLE t2 (x INT); +//! ``` +//! +//! ## Transactions & Idempotency +//! +//! All steps will be executed within one transaction. However, you can opt-out of this: +//! +//! ```sql +//! -- this step is wrapped in a transaction +//! CREATE TABLE t1 (x INT); +//! +//! -- IOX_STEP_BOUNDARY +//! +//! -- this step isn't +//! -- IOX_NO_TRANSACTION +//! CREATE TABLE t2 (x INT); +//! ``` +//! +//! If all steps can be run in a transaction, the entire migration (including its bookkeeping) will be executed in a +//! transaction. In this case, the transaction is automatically idempotent. +//! +//! Migrations that opt out of the transaction handling MUST ensure that they are idempotent. This also includes that +//! they end up in the desired target state even if they were interrupted midway in a previous run. +//! +//! ## Updating / Fixing Migrations +//! +//! **⚠️ In general a migration MUST NOT be updated / changed after it was committed to `main`. ⚠️** +//! +//! However, there is one exception to this rule: if the new version has the same outcome when applied successfully. +//! This can be due to: +//! +//! - **Optimization:** The migration script turns out to be too slow in production workloads, but you find a better +//! version that does the same but runs faster. +//! - **Failure:** The script worked fine during testing but in prod it always fails, e.g. because it is missing NULL +//! handling. It is important to remember that the fix MUST NOT change the outcome of the successful runs. +//! - **Idempotency:** The script works only w/o transactions (see section above) and cannot be re-applied when +//! interrupted midway. One common case is `CREATE INDEX CONCURRENTLY ...` where you MUST drop the index beforehand +//! 
via `DROP INDEX IF EXISTS ...` because a previous interrupted migration might have left it in an invalid state. +//! See ["Building Indexes Concurrently"]. +//! +//! If you are very sure that you found a fix for your migration that does the same operation, you still MUST NOT just +//! change the existing migration. The reason is that we keep a checksum of the migration stored in the database. +//! Changing the script will change the checksum, which will lead to a [failure](MigrateError::VersionMismatch) when +//! running the migrations. You can work around that by obtaining the old checksum (in hex) and adding it to the new +//! version as: `-- IOX_OTHER_CHECKSUM: 42feedbull`. This pragma can be repeated multiple times. +//! +//! ### Example +//! +//! If the old migration script looks like this: +//! +//! ```sql +//! -- IOX_NO_TRANSACTION +//! SET statement_timeout TO '60min'; +//! +//! -- IOX_STEP_BOUNDARY +//! +//! -- IOX_NO_TRANSACTION +//! CREATE INDEX CONCURRENTLY IF NOT EXISTS i ON t (x); +//! ``` +//! +//! You can fix the idempotency by creating a new migration that contains: +//! +//! ```sql +//! -- IOX_OTHER_CHECKSUM: 067431eaa74f26ee86200aaed4992a5fe22354322102f1ed795e424ec529469079569072d856e96ee9fdb6cc848b6137 +//! -- IOX_NO_TRANSACTION +//! SET statement_timeout TO '60min'; +//! +//! -- IOX_STEP_BOUNDARY +//! DROP INDEX CONCURRENTLY IF EXISTS i; +//! +//! -- IOX_NO_TRANSACTION +//! +//! -- IOX_STEP_BOUNDARY +//! +//! -- IOX_NO_TRANSACTION +//! CREATE INDEX CONCURRENTLY IF NOT EXISTS i ON t (x); +//! ``` +//! +//! ## Non-SQL steps +//! +//! At the moment, we only support SQL-based migration steps, but other step types can easily be added. +//! +//! 
["Building Indexes Concurrently"]: https://www.postgresql.org/docs/15/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY + +use std::{ + borrow::Cow, + collections::{HashMap, HashSet}, + hash::{Hash, Hasher}, + ops::Deref, + str::FromStr, + time::{Duration, Instant}, +}; + +use async_trait::async_trait; +use observability_deps::tracing::{debug, info, warn}; +use siphasher::sip::SipHasher13; +use sqlx::{ + migrate::{Migrate, MigrateError, Migration, MigrationType, Migrator}, + query, query_as, query_scalar, Acquire, Connection, Executor, PgConnection, Postgres, + Transaction, +}; + +/// A single [`IOxMigration`] step. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum IOxMigrationStep { + /// Execute a SQL statement. + /// + /// A SQL statement MAY contain multiple sub-statements, e.g.: + /// + /// ```sql + /// CREATE TABLE IF NOT EXISTS table1 ( + /// id BIGINT GENERATED ALWAYS AS IDENTITY, + /// PRIMARY KEY (id), + /// ); + /// + /// CREATE TABLE IF NOT EXISTS table2 ( + /// id BIGINT GENERATED ALWAYS AS IDENTITY, + /// PRIMARY KEY (id), + /// ); + /// ``` + SqlStatement { + /// The SQL text. + /// + /// If [`in_transaction`](Self::SqlStatement::in_transaction) is set, this MUST NOT contain any transaction + /// modifiers like `COMMIT`/`ROLLBACK`/`BEGIN`! + sql: Cow<'static, str>, + + /// Should the execution of the SQL text be wrapped into a transaction? + /// + /// Whenever possible, you likely want to set this to `true`. However, some database changes like `CREATE INDEX + /// CONCURRENTLY` under PostgreSQL cannot be executed within a transaction. + in_transaction: bool, + }, +} + +impl IOxMigrationStep { + /// Apply migration step. + async fn apply(&self, conn: &mut C) -> Result<(), MigrateError> + where + C: IOxMigrate, + { + match self { + Self::SqlStatement { sql, .. } => { + conn.exec(sql).await?; + } + } + + Ok(()) + } + + /// Will this step set up a transaction if there is none yet? 
+ fn in_transaction(&self) -> bool { + match self { + Self::SqlStatement { in_transaction, .. } => *in_transaction, + } + } +} + +/// Migration checksum. +#[derive(Clone, PartialEq, Eq)] +pub struct Checksum(Box<[u8]>); + +impl Checksum { + fn as_bytes(&self) -> &[u8] { + &self.0 + } +} + +impl std::fmt::Debug for Checksum { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for b in &*self.0 { + write!(f, "{:02x}", b)?; + } + Ok(()) + } +} + +impl std::fmt::Display for Checksum { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl From<[u8; N]> for Checksum { + fn from(value: [u8; N]) -> Self { + Self(value.into()) + } +} + +impl From<&[u8]> for Checksum { + fn from(value: &[u8]) -> Self { + Self(value.into()) + } +} + +impl FromStr for Checksum { + type Err = MigrateError; + + fn from_str(s: &str) -> Result { + let inner = (0..s.len()) + .step_by(2) + .map(|i| u8::from_str_radix(&s[i..(i + 2).min(s.len())], 16)) + .collect::, _>>() + .map_err(|e| { + MigrateError::Source(format!("cannot parse checksum '{s}': {e}").into()) + })?; + + Ok(Self(inner)) + } +} + +/// Database migration. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct IOxMigration { + /// Version. + /// + /// This is used to order migrations. + pub version: i64, + + /// Human-readable description. + pub description: Cow<'static, str>, + + /// Steps that compose this migration. + /// + /// In most cases you want a single [SQL step](IOxMigrationStep::SqlStatement) which is executed + /// [in a transaction](IOxMigrationStep::SqlStatement::in_transaction). + pub steps: Box<[IOxMigrationStep]>, + + /// Checksum of the given steps. + pub checksum: Checksum, + + /// Checksums of other versions of this migration that are known to be compatible. + /// + /// **Using this should be a rare exception!** + /// + /// This can be used to convert a non-idempotent migration into an idempotent one. 
+ pub other_compatible_checksums: Box<[Checksum]>, +} + +impl IOxMigration { + /// Apply migration and return elapsed wall-clock time (measured locally). + async fn apply(&self, conn: &mut C) -> Result + where + C: IOxMigrate, + { + let single_transaction = self.single_transaction(); + info!( + version = self.version, + description = self.description.as_ref(), + steps = self.steps.len(), + single_transaction, + "applying migration" + ); + + let elapsed = if single_transaction { + let mut txn = conn.begin_txn().await?; + let elapsed = { + let conn = txn.acquire_conn().await?; + self.apply_inner(conn, true).await? + }; + txn.commit_txn().await?; + elapsed + } else { + self.apply_inner(conn, false).await? + }; + + info!( + version = self.version, + description = self.description.as_ref(), + steps = self.steps.len(), + elapsed_secs = elapsed.as_secs_f64(), + "migration applied" + ); + + Ok(elapsed) + } + + /// Run actual application of the migration. + /// + /// This may or may NOT be guarded by a transaction block. + async fn apply_inner(&self, conn: &mut C, single_txn: bool) -> Result + where + C: IOxMigrate, + { + let start = Instant::now(); + conn.start_migration(self).await?; + + for (i, step) in self.steps.iter().enumerate() { + info!( + version = self.version, + steps = self.steps.len(), + step = i + 1, + single_txn, + in_transaction = step.in_transaction(), + "applying migration step" + ); + + if step.in_transaction() && !single_txn { + let mut txn = conn.begin_txn().await?; + { + let conn = txn.acquire_conn().await?; + step.apply(conn).await?; + } + txn.commit_txn().await?; + } else { + step.apply(conn).await?; + } + + info!( + version = self.version, + steps = self.steps.len(), + step = i + 1, + "applied migration step" + ); + } + + let elapsed = start.elapsed(); + conn.run_sanity_checks().await?; + conn.finish_migration(self, elapsed).await?; + + Ok(elapsed) + } + + /// This migration can be run in a single transaction and will never be dirty. 
+ pub fn single_transaction(&self) -> bool { + self.steps.iter().all(|s| s.in_transaction()) + } +} + +impl TryFrom<&Migration> for IOxMigration { + type Error = MigrateError; + + fn try_from(migration: &Migration) -> Result { + if migration.migration_type != MigrationType::Simple { + return Err(MigrateError::Source( + format!( + "migration type has to be simple but is {:?}", + migration.migration_type + ) + .into(), + )); + } + + let other_compatible_checksums = migration + .sql + .lines() + .filter_map(|s| { + s.strip_prefix("-- IOX_OTHER_CHECKSUM:") + .map(|s| s.trim().parse()) + }) + .collect::>()?; + + let steps = migration + .sql + .split("-- IOX_STEP_BOUNDARY") + .map(|sql| { + let sql = sql.trim().to_owned(); + let in_transaction = !sql.contains("IOX_NO_TRANSACTION"); + IOxMigrationStep::SqlStatement { + sql: sql.into(), + in_transaction, + } + }) + .collect(); + + Ok(Self { + version: migration.version, + description: migration.description.clone(), + steps, + // Keep original (unprocessed) checksum for backwards compatibility. + checksum: migration.checksum.deref().into(), + other_compatible_checksums, + }) + } +} + +/// Migration manager. +#[derive(Debug, PartialEq, Eq)] +pub struct IOxMigrator { + /// List of migrations. + migrations: Vec, +} + +impl IOxMigrator { + /// Create new migrator. + /// + /// # Error + /// Fails if migrations are not sorted or if there are duplicate [versions](IOxMigration::version). 
+ pub fn try_new( + migrations: impl IntoIterator, + ) -> Result { + let migrations = migrations.into_iter().collect::>(); + + if let Some(m) = migrations.windows(2).find(|m| m[0].version > m[1].version) { + return Err(MigrateError::Source( + format!( + "migrations are not sorted: version {} is before {} but should not be", + m[0].version, m[1].version, + ) + .into(), + )); + } + if let Some(m) = migrations.windows(2).find(|m| m[0].version == m[1].version) { + return Err(MigrateError::Source( + format!( + "migrations are not unique: version {} found twice", + m[0].version, + ) + .into(), + )); + } + + Ok(Self { migrations }) + } + + /// Run migrator on connection/pool. + /// + /// Returns set of executed [migrations](IOxMigration). + /// + /// This may fail and some migrations may be applied. Also, it is possible that a migration itself fails half-way, + /// in which case it is marked as dirty. Subsequent migrations will fail until the issue is resolved. + pub async fn run<'a, A>(&self, migrator: A) -> Result, MigrateError> + where + A: Acquire<'a> + Send, + ::Target: IOxMigrate, + { + let mut conn = migrator.acquire().await?; + self.run_direct(&mut *conn).await + } + + /// Run migrator on open connection. + /// + /// See docs for [run](Self::run). + async fn run_direct(&self, conn: &mut C) -> Result, MigrateError> + where + C: IOxMigrate, + { + let lock_id = conn.generate_lock_id().await?; + ::lock(conn, lock_id).await?; + + let run_res = self.run_inner(conn).await; + + // always try to unlock, even when we failed. + // While PG is timing out the lock, unlocking manually will give others the chance to re-lock faster. This is + // mostly relevant for tests where we re-use connections. 
+ let unlock_res = ::unlock(conn, lock_id).await; + + // return first error but also first OK (there doesn't seem to be an stdlib method for this) + match (run_res, unlock_res) { + (Err(e), _) => Err(e), + (Ok(_), Err(e)) => Err(e), + (Ok(res), Ok(())) => Ok(res), + } + } + + /// Run migrator. + /// + /// This expects that locking was already done. + async fn run_inner(&self, conn: &mut C) -> Result, MigrateError> + where + C: IOxMigrate, + { + // creates [_migrations] table only if needed + // eventually this will likely migrate previous versions of the table + conn.ensure_migrations_table().await?; + + let applied_migrations = ::list_applied_migrations(conn).await?; + validate_applied_migrations(&applied_migrations, self)?; + + let applied_and_not_dirty: HashSet<_> = applied_migrations + .into_iter() + .filter(|m| !m.dirty) + .map(|m| m.version) + .collect(); + + let mut new_migrations = HashSet::new(); + for migration in &self.migrations { + if applied_and_not_dirty.contains(&migration.version) { + continue; + } + migration.apply(conn).await?; + new_migrations.insert(migration.version); + } + + Ok(new_migrations) + } +} + +impl TryFrom<&Migrator> for IOxMigrator { + type Error = MigrateError; + + fn try_from(migrator: &Migrator) -> Result { + if migrator.ignore_missing { + return Err(MigrateError::Source( + "`Migrator::ignore_missing` MUST NOT be set" + .to_owned() + .into(), + )); + } + if !migrator.locking { + return Err(MigrateError::Source( + "`Migrator::locking` MUST be set".to_owned().into(), + )); + } + + let migrations = migrator + .migrations + .iter() + .map(|migration| migration.try_into()) + .collect::, _>>()?; + + Self::try_new(migrations) + } +} + +/// Validate already-applied migrations +/// +/// Checks that: +/// +/// - all applied migrations are known or all known migrations are applied +/// - checksum of applied migration and known migration match +/// - new migrations are newer than both the successfully applied and the dirty version +/// - 
there is at most one dirty migration (bug check) +/// - the dirty migration is the last applied one (bug check) +fn validate_applied_migrations( + applied_migrations: &[IOxAppliedMigration], + migrator: &IOxMigrator, +) -> Result<(), MigrateError> { + let migrations: HashMap<_, _> = migrator.migrations.iter().map(|m| (m.version, m)).collect(); + + let mut dirty_version = None; + for (idx, applied_migration) in applied_migrations.iter().enumerate() { + match migrations.get(&applied_migration.version) { + None => { + if idx == migrations.len() && dirty_version.is_none() { + // All migrations in `migrator` have been applied + // We therefore continue as this should not prevent startup + // if there are no local migrations to apply + warn!("found applied migrations not present locally, but all local migrations applied - continuing"); + return Ok(()); + } + + return Err(MigrateError::VersionMissing(applied_migration.version)); + } + Some(migration) => { + if !std::iter::once(&migration.checksum) + .chain(migration.other_compatible_checksums.iter()) + .any(|cs| cs.as_bytes() == applied_migration.checksum.deref()) + { + return Err(MigrateError::VersionMismatch(migration.version)); + } + + if applied_migration.dirty { + if let Some(first) = dirty_version { + return Err(MigrateError::Source(format!( + "there are multiple dirty versions, this should not happen and is considered a bug: {:?}", + &[first, migration.version], + ).into())); + } + dirty_version = Some(migration.version); + warn!( + version = migration.version, + "found dirty migration, trying to recover" + ); + } + } + } + } + + let applied_last = applied_migrations + .iter() + .filter(|m| Some(m.version) != dirty_version) + .map(|m| m.version) + .max(); + if let (Some(applied_last), Some(dirty_version)) = (applied_last, dirty_version) { + // algorithm error in this method, use an assertion + assert_ne!(applied_last, dirty_version); + + if applied_last > dirty_version { + // database state error, so use a proper 
error + return Err(MigrateError::Source(format!( + "dirty version ({dirty_version}) is not the last applied version ({applied_last}), this is a bug", + ).into())); + } + } + + let applied_set = applied_migrations + .iter() + .map(|m| m.version) + .collect::>(); + let new_first = migrator + .migrations + .iter() + .filter(|m| !applied_set.contains(&m.version)) + .map(|m| m.version) + .min(); + if let (Some(dirty_version), Some(new_first)) = (dirty_version, new_first) { + // algorithm error in this method, use an assertion + assert_ne!(dirty_version, new_first); + + if dirty_version > new_first { + // database state error, so use a proper error + return Err(MigrateError::Source( + format!( + "new migration ({new_first}) goes before dirty version ({dirty_version}), \ + this should not have been merged!", + ) + .into(), + )); + } + } + if let (Some(applied_last), Some(new_first)) = (applied_last, new_first) { + // algorithm error in this method, use an assertion + assert_ne!(applied_last, new_first); + + if applied_last > new_first { + // database state error, so use a proper error + return Err(MigrateError::Source( + format!( + "new migration ({new_first}) goes before last applied migration ({applied_last}), \ + this should not have been merged!", + ) + .into(), + )); + } + } + + Ok(()) +} + +/// Information about a migration found in the database. +#[derive(Debug)] +pub struct IOxAppliedMigration { + /// Version of the migration. + pub version: i64, + + /// Checksum. + pub checksum: Cow<'static, [u8]>, + + /// Dirty flag. + /// + /// If this is set, then the migration was interrupted midway. + pub dirty: bool, +} + +/// Transaction type linked to [`IOxMigrate`]. +/// +/// This is a separate type because we need to own the transaction object at some point before handing out mutable +/// borrows to the actual connection again. +#[async_trait] +pub trait IOxMigrateTxn: Send { + /// The migration interface. + type M: IOxMigrate; + + /// Acquire connection. 
+ async fn acquire_conn(&mut self) -> Result<&mut Self::M, MigrateError>; + + /// Commit transaction. + async fn commit_txn(self) -> Result<(), MigrateError>; +} + +/// Interface of a specific database implementation (like Postgres) and the IOx migration system. +/// +/// This mostly delegates to the SQLx [`Migrate`] interface but also has some extra methods. +#[async_trait] +pub trait IOxMigrate: Connection + Migrate + Send { + /// Transaction type. + type Txn<'a>: IOxMigrateTxn + where + Self: 'a; + + /// Start new transaction. + async fn begin_txn<'a>(&'a mut self) -> Result, MigrateError>; + + /// Generate a lock ID that is used for [`lock`](Self::lock) and [`unlock`](Self::unlock). + async fn generate_lock_id(&mut self) -> Result; + + /// Lock database for migrations. + async fn lock(&mut self, lock_id: i64) -> Result<(), MigrateError>; + + /// Unlock database after migration. + async fn unlock(&mut self, lock_id: i64) -> Result<(), MigrateError>; + + /// Get list of applied migrations. + async fn list_applied_migrations(&mut self) -> Result, MigrateError>; + + /// Start a migration and mark it as "not finished". + async fn start_migration(&mut self, migration: &IOxMigration) -> Result<(), MigrateError>; + + /// Finish a migration and register the elapsed time. + async fn finish_migration( + &mut self, + migration: &IOxMigration, + elapsed: Duration, + ) -> Result<(), MigrateError>; + + /// Execute a SQL statement (that may contain multiple sub-statements) + async fn exec(&mut self, sql: &str) -> Result<(), MigrateError>; + + /// Run DB-specific sanity checks on the schema. + /// + /// This mostly includes checks for "validity" markers (e.g. for indices). 
+ async fn run_sanity_checks(&mut self) -> Result<(), MigrateError>; +} + +#[async_trait] +impl<'a> IOxMigrateTxn for Transaction<'a, Postgres> { + type M = PgConnection; + + async fn acquire_conn(&mut self) -> Result<&mut Self::M, MigrateError> { + let conn = self.acquire().await?; + Ok(conn) + } + + async fn commit_txn(self) -> Result<(), MigrateError> { + self.commit().await?; + Ok(()) + } +} + +#[async_trait] +impl IOxMigrate for PgConnection { + type Txn<'a> = Transaction<'a, Postgres>; + + async fn begin_txn<'a>(&'a mut self) -> Result, MigrateError> { + let txn = ::begin(self).await?; + Ok(txn) + } + + async fn generate_lock_id(&mut self) -> Result { + let db: String = query_scalar("SELECT current_database()") + .fetch_one(self) + .await?; + + // A randomly generated static siphash key to ensure all migrations use the same locks. + // + // Generated with: xxd -i -l 16 /dev/urandom + let key = [ + 0xb8, 0x52, 0x81, 0x3c, 0x12, 0x83, 0x6f, 0xd9, 0x00, 0x4f, 0xe7, 0xe3, 0x61, 0xbd, + 0x03, 0xaf, + ]; + + let mut hasher = SipHasher13::new_with_key(&key); + db.hash(&mut hasher); + + Ok(i64::from_ne_bytes(hasher.finish().to_ne_bytes())) + } + + async fn lock(&mut self, lock_id: i64) -> Result<(), MigrateError> { + loop { + let is_locked: bool = query_scalar("SELECT pg_try_advisory_lock($1)") + .bind(lock_id) + .fetch_one(&mut *self) + .await?; + + if is_locked { + return Ok(()); + } + + let t_wait = Duration::from_millis(20); + debug!( + lock_id, + t_wait_millis = t_wait.as_millis(), + "lock held, waiting" + ); + tokio::time::sleep(t_wait).await; + } + } + + async fn unlock(&mut self, lock_id: i64) -> Result<(), MigrateError> { + let was_locked: bool = query_scalar("SELECT pg_advisory_unlock($1)") + .bind(lock_id) + .fetch_one(self) + .await?; + + if !was_locked { + return Err(MigrateError::Source( + format!("did not own lock: {lock_id}").into(), + )); + } + + Ok(()) + } + + async fn list_applied_migrations(&mut self) -> Result, MigrateError> { + let rows: 
Vec<(i64, Vec, bool)> = query_as( + "SELECT version, checksum, NOT success FROM _sqlx_migrations ORDER BY version", + ) + .fetch_all(self) + .await?; + + let migrations = rows + .into_iter() + .map(|(version, checksum, dirty)| IOxAppliedMigration { + version, + checksum: checksum.into(), + dirty, + }) + .collect(); + + Ok(migrations) + } + + async fn start_migration(&mut self, migration: &IOxMigration) -> Result<(), MigrateError> { + let _ = query( + r#" +INSERT INTO _sqlx_migrations ( version, description, success, checksum, execution_time ) +VALUES ( $1, $2, FALSE, $3, -1 ) +ON CONFLICT (version) +DO NOTHING + "#, + ) + .bind(migration.version) + .bind(&*migration.description) + .bind(migration.checksum.as_bytes()) + .execute(self) + .await?; + + Ok(()) + } + + async fn finish_migration( + &mut self, + migration: &IOxMigration, + elapsed: Duration, + ) -> Result<(), MigrateError> { + let _ = query( + r#" +UPDATE _sqlx_migrations +SET success = TRUE, execution_time = $1 +WHERE version = $2 + "#, + ) + .bind(elapsed.as_nanos() as i64) + .bind(migration.version) + .execute(self) + .await?; + + Ok(()) + } + + async fn exec(&mut self, sql: &str) -> Result<(), MigrateError> { + let _ = self.execute(sql).await?; + Ok(()) + } + + async fn run_sanity_checks(&mut self) -> Result<(), MigrateError> { + let dirty_indices: Vec = query_scalar( + r#" +SELECT pg_class.relname +FROM pg_index +JOIN pg_class ON pg_index.indexrelid = pg_class.oid +JOIN pg_namespace ON pg_class.relnamespace = pg_namespace.oid +WHERE pg_namespace.nspname = current_schema() AND NOT pg_index.indisvalid +ORDER BY pg_class.relname + "#, + ) + .fetch_all(self) + .await?; + + if !dirty_indices.is_empty() { + return Err(MigrateError::Source( + format!("Found invalid indexes: {}", dirty_indices.join(", ")).into(), + )); + } + + Ok(()) + } +} + +/// Testing tools for migrations. +#[cfg(test)] +pub mod test_utils { + use super::*; + + use std::future::Future; + + /// Test migration. 
+ /// + /// This runs the migrations to check if they pass. The given factory must provide an empty schema (i.e. w/o any + /// migrations applied). + /// + /// # Tests + /// + /// This tests that: + /// + /// - **run once:** All migrations work when ran once. + /// - **idempotency:** Migrations marked as [`idempotent`](IOxMigration::idempotent) can be executed twice. + /// + /// # Error + /// + /// Fails if this finds a bug. + pub async fn test_migration( + migrator: &IOxMigrator, + factory: Factory, + ) -> Result<(), MigrateError> + where + Factory: (Fn() -> FactoryFut) + Send + Sync, + FactoryFut: Future + Send, + Pool: Send, + for<'a> &'a Pool: Acquire<'a> + Send, + for<'a> <<&'a Pool as Acquire<'a>>::Connection as Deref>::Target: IOxMigrate, + { + { + info!("test: run all migrations"); + let conn = factory().await; + let applied = migrator.run(&conn).await?; + assert_eq!(applied.len(), migrator.migrations.len()); + } + + info!("interrupt non-transaction migrations"); + for (idx_m, m) in migrator.migrations.iter().enumerate() { + if m.single_transaction() { + info!( + version = m.version, + "skip migration because single transaction property" + ); + continue; + } + + let steps = m.steps.len(); + info!( + version = m.version, + steps, "found non-transactional migration" + ); + + for step in 1..(steps + 1) { + info!(version = m.version, steps, step, "test: die after step"); + + let broken_cmd = "iox_this_is_a_broken_test_cmd"; + let migrator_broken = IOxMigrator::try_new( + migrator + .migrations + .iter() + .take(idx_m) + .cloned() + .chain(std::iter::once(IOxMigration { + steps: m + .steps + .iter() + .take(step) + .cloned() + .chain(std::iter::once(IOxMigrationStep::SqlStatement { + sql: broken_cmd.into(), + in_transaction: false, + })) + .collect(), + ..m.clone() + })), + ) + .expect("bug in test"); + + let conn = factory().await; + let err = migrator_broken.run(&conn).await.unwrap_err(); + if !err.to_string().contains(broken_cmd) { + panic!("migrator broke in 
expected way, bug in test setup: {err}"); + } + + info!( + version = m.version, + steps, step, "test: die after step, recover from error" + ); + let applied = migrator.run(&conn).await?; + assert!(applied.contains(&m.version)); + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + mod generic { + use super::*; + + use proptest::prelude::*; + + proptest! { + #[test] + fn test_checksum_string_roundtrip(s: Vec) { + let checksum_1 = Checksum::from(s.as_slice()); + let string_1 = checksum_1.to_string(); + let checksum_2 = Checksum::from_str(&string_1).unwrap(); + let string_2 = checksum_2.to_string(); + assert_eq!(checksum_1, checksum_2); + assert_eq!(string_1, string_2); + } + } + + #[test] + fn test_parse_valid_checksum() { + let actual = Checksum::from_str( + "b88c635e27f8b9ba8547b24efcb081429a8f3e85b70f35916e1900dffc4e6a77eed8a02acc7c72526dd7d50166b63fbd" + ).unwrap(); + let expected = Checksum::from([ + 184, 140, 99, 94, 39, 248, 185, 186, 133, 71, 178, 78, 252, 176, 129, 66, 154, 143, + 62, 133, 183, 15, 53, 145, 110, 25, 0, 223, 252, 78, 106, 119, 238, 216, 160, 42, + 204, 124, 114, 82, 109, 215, 213, 1, 102, 182, 63, 189, + ]); + + assert_eq!(actual, expected); + } + + #[test] + fn test_parse_invalid_checksum() { + let err = Checksum::from_str("foo").unwrap_err(); + + assert_eq!( + err.to_string(), + "while resolving migrations: cannot parse checksum 'foo': invalid digit found in string", + ); + } + + #[test] + fn test_migrator_new_error_not_sorted() { + let err = IOxMigrator::try_new([ + IOxMigration { + version: 2, + description: "".into(), + steps: [].into(), + checksum: [].into(), + other_compatible_checksums: [].into(), + }, + IOxMigration { + version: 1, + description: "".into(), + steps: [].into(), + checksum: [].into(), + other_compatible_checksums: [].into(), + }, + ]) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "while resolving migrations: migrations are not sorted: version 2 is before 1 but should not be", + ); + } + + 
#[test] + fn test_migrator_new_error_not_unique() { + let err = IOxMigrator::try_new([ + IOxMigration { + version: 2, + description: "".into(), + steps: [].into(), + checksum: [].into(), + other_compatible_checksums: [].into(), + }, + IOxMigration { + version: 2, + description: "".into(), + steps: [].into(), + checksum: [].into(), + other_compatible_checksums: [].into(), + }, + ]) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "while resolving migrations: migrations are not unique: version 2 found twice", + ); + } + + #[test] + fn test_convert_migrator_from_sqlx_error_no_locking() { + let err = IOxMigrator::try_from(&Migrator { + migrations: vec![].into(), + ignore_missing: false, + locking: false, + }) + .unwrap_err(); + assert_eq!( + err.to_string(), + "while resolving migrations: `Migrator::locking` MUST be set", + ); + } + + #[test] + fn test_convert_migrator_from_sqlx_error_ignore_missing() { + let err = IOxMigrator::try_from(&Migrator { + migrations: vec![].into(), + ignore_missing: true, + locking: true, + }) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "while resolving migrations: `Migrator::ignore_missing` MUST NOT be set", + ); + } + + #[test] + fn test_convert_migrator_from_sqlx_error_invalid_migration_type_rev_up() { + let err = IOxMigrator::try_from(&Migrator { + migrations: vec![Migration { + version: 1, + description: "".into(), + migration_type: MigrationType::ReversibleUp, + sql: "".into(), + checksum: vec![].into(), + }] + .into(), + ignore_missing: false, + locking: true, + }) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "while resolving migrations: migration type has to be simple but is ReversibleUp", + ); + } + + #[test] + fn test_convert_migrator_from_sqlx_error_invalid_migration_type_rev_down() { + let err = IOxMigrator::try_from(&Migrator { + migrations: vec![Migration { + version: 1, + description: "".into(), + migration_type: MigrationType::ReversibleDown, + sql: "".into(), + checksum: vec![].into(), + }] + 
.into(), + ignore_missing: false, + locking: true, + }) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "while resolving migrations: migration type has to be simple but is ReversibleDown", + ); + } + + #[test] + fn test_convert_migrator_from_sqlx_error_invalid_other_compatible_checksum() { + let err = IOxMigrator::try_from(&Migrator { + migrations: vec![Migration { + version: 1, + description: "".into(), + migration_type: MigrationType::Simple, + sql: "-- IOX_OTHER_CHECKSUM: foo".into(), + checksum: vec![].into(), + }] + .into(), + ignore_missing: false, + locking: true, + }) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "while resolving migrations: cannot parse checksum 'foo': invalid digit found in string", + ); + } + + #[test] + fn test_convert_migrator_from_sqlx_ok() { + let actual = IOxMigrator::try_from(&Migrator { + migrations: vec![ + Migration { + version: 1, + description: "some descr".into(), + migration_type: MigrationType::Simple, + sql: "SELECT 1;".into(), + checksum: vec![1, 2, 3].into(), + }, + Migration { + version: 10, + description: "more descr".into(), + migration_type: MigrationType::Simple, + sql: "SELECT 2;\n-- IOX_STEP_BOUNDARY\n-- IOX_NO_TRANSACTION\nSELECT 3;" + .into(), + checksum: vec![4, 5, 6].into(), + }, + Migration { + version: 11, + description: "xxx".into(), + migration_type: MigrationType::Simple, + sql: "-- IOX_OTHER_CHECKSUM:1ff\n-- IOX_OTHER_CHECKSUM: 2ff \nSELECT4;" + .into(), + checksum: vec![7, 8, 9].into(), + }, + ] + .into(), + ignore_missing: false, + locking: true, + }) + .unwrap(); + + let expected = IOxMigrator { + migrations: vec![ + IOxMigration { + version: 1, + description: "some descr".into(), + steps: [IOxMigrationStep::SqlStatement { + sql: "SELECT 1;".into(), + in_transaction: true, + }] + .into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }, + IOxMigration { + version: 10, + description: "more descr".into(), + steps: [ + IOxMigrationStep::SqlStatement { + sql: 
"SELECT 2;".into(), + in_transaction: true, + }, + IOxMigrationStep::SqlStatement { + sql: "-- IOX_NO_TRANSACTION\nSELECT 3;".into(), + in_transaction: false, + }, + ] + .into(), + checksum: [4, 5, 6].into(), + other_compatible_checksums: [].into(), + }, + IOxMigration { + version: 11, + description: "xxx".into(), + steps: [IOxMigrationStep::SqlStatement { + sql: "-- IOX_OTHER_CHECKSUM:1ff\n-- IOX_OTHER_CHECKSUM: 2ff \nSELECT4;".into(), + in_transaction: true, + }] + .into(), + checksum: [7, 8, 9].into(), + other_compatible_checksums: [ + Checksum::from_str("1ff").unwrap(), + Checksum::from_str("2ff").unwrap(), + ].into(), + }, + ], + }; + + assert_eq!(actual, expected); + } + } + + mod postgres { + use std::sync::Arc; + + use futures::{stream::FuturesUnordered, StreamExt}; + use sqlx::{pool::PoolConnection, Postgres}; + use sqlx_hotswap_pool::HotSwapPool; + use test_helpers::maybe_start_logging; + + use crate::postgres::test_utils::{maybe_skip_integration, setup_db_no_migration}; + + use super::*; + + #[tokio::test] + async fn test_lock_id_deterministic() { + maybe_skip_integration!(); + + let mut conn = setup().await; + let conn = &mut *conn; + + let first = conn.generate_lock_id().await.unwrap(); + let second = conn.generate_lock_id().await.unwrap(); + assert_eq!(first, second); + } + + #[tokio::test] + async fn test_lock_unlock_twice() { + maybe_skip_integration!(); + + let mut conn = setup().await; + let conn = &mut *conn; + + let lock_id = conn.generate_lock_id().await.unwrap(); + + ::lock(conn, lock_id) + .await + .unwrap(); + ::unlock(conn, lock_id) + .await + .unwrap(); + + ::lock(conn, lock_id) + .await + .unwrap(); + ::unlock(conn, lock_id) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_lock_prevents_2nd_lock() { + maybe_skip_integration!(); + + let pool = setup_pool().await; + + let mut conn1 = pool.acquire().await.unwrap(); + let conn1 = &mut *conn1; + + let mut conn2 = pool.acquire().await.unwrap(); + let conn2 = &mut *conn2; + + let 
lock_id = conn1.generate_lock_id().await.unwrap(); + + ::lock(conn1, lock_id) + .await + .unwrap(); + tokio::time::timeout(Duration::from_secs(1), async { + ::lock(conn2, lock_id) + .await + .unwrap(); + }) + .await + .unwrap_err(); + ::unlock(conn1, lock_id) + .await + .unwrap(); + + ::lock(conn2, lock_id) + .await + .unwrap(); + ::unlock(conn2, lock_id) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_locks_are_scoped() { + maybe_skip_integration!(); + + let pool = setup_pool().await; + + let mut conn1 = pool.acquire().await.unwrap(); + let conn1 = &mut *conn1; + + let mut conn2 = pool.acquire().await.unwrap(); + let conn2 = &mut *conn2; + + let lock_id1 = conn1.generate_lock_id().await.unwrap(); + let lock_id2 = !lock_id1; + + ::lock(conn1, lock_id1) + .await + .unwrap(); + ::lock(conn1, lock_id2) + .await + .unwrap(); + ::unlock(conn1, lock_id1) + .await + .unwrap(); + + // id2 is still lock (i.e. unlock is also scoped) + tokio::time::timeout(Duration::from_secs(1), async { + ::lock(conn2, lock_id2) + .await + .unwrap(); + }) + .await + .unwrap_err(); + + ::unlock(conn1, lock_id2) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_unlock_without_lock_fails() { + maybe_skip_integration!(); + + let mut conn = setup().await; + let conn = &mut *conn; + + let lock_id = conn.generate_lock_id().await.unwrap(); + + let err = ::unlock(conn, lock_id) + .await + .unwrap_err(); + + assert_starts_with( + &err.to_string(), + "while resolving migrations: did not own lock:", + ); + } + + #[tokio::test] + async fn test_step_sql_statement_no_transaction() { + maybe_skip_integration!(); + + for in_transaction in [false, true] { + println!("in_transaction: {in_transaction}"); + + let mut conn = setup().await; + let conn = &mut *conn; + + conn.execute("CREATE TABLE t (x INT);").await.unwrap(); + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: [IOxMigrationStep::SqlStatement { + sql: "CREATE INDEX 
CONCURRENTLY i ON t (x);".into(), + in_transaction, + }] + .into(), + checksum: [].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + let res = migrator.run_direct(conn).await; + + match in_transaction { + false => { + assert_eq!(res.unwrap(), HashSet::from([1]),); + } + true => { + // `CREATE INDEX CONCURRENTLY` is NOT possible w/ a transaction. Verify that. + assert_eq!( + res.unwrap_err().to_string(), + "while executing migrations: error returned from database: \ + CREATE INDEX CONCURRENTLY cannot run inside a transaction block", + ); + } + } + } + } + + #[tokio::test] + async fn test_migrator_happy_path() { + maybe_skip_integration!(); + let mut conn = setup().await; + let conn = &mut *conn; + + let migrator = IOxMigrator::try_new([ + IOxMigration { + version: 1, + description: "".into(), + steps: [ + IOxMigrationStep::SqlStatement { + sql: "CREATE TABLE t (x INT);".into(), + in_transaction: false, + }, + IOxMigrationStep::SqlStatement { + sql: "INSERT INTO t (x) VALUES (1); INSERT INTO t (x) VALUES (10);" + .into(), + in_transaction: true, + }, + ] + .into(), + checksum: [].into(), + other_compatible_checksums: [].into(), + }, + IOxMigration { + version: 2, + description: "".into(), + steps: [IOxMigrationStep::SqlStatement { + sql: "INSERT INTO t (x) VALUES (100);".into(), + in_transaction: true, + }] + .into(), + checksum: [].into(), + other_compatible_checksums: [].into(), + }, + ]) + .unwrap(); + + let applied = migrator.run_direct(conn).await.unwrap(); + assert_eq!(applied, HashSet::from([1, 2])); + + let r: i32 = query_scalar("SELECT SUM(x)::INT AS r FROM t;") + .fetch_one(conn) + .await + .unwrap(); + + assert_eq!(r, 111); + } + + #[tokio::test] + async fn test_migrator_only_apply_new_migrations() { + maybe_skip_integration!(); + let mut conn = setup().await; + let conn = &mut *conn; + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: [IOxMigrationStep::SqlStatement { + // NOT 
idempotent! + sql: "CREATE TABLE t (x INT);".into(), + in_transaction: false, + }] + .into(), + checksum: [].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + + let applied = migrator.run_direct(conn).await.unwrap(); + assert_eq!(applied, HashSet::from([1])); + + let migrator = IOxMigrator::try_new( + migrator.migrations.iter().cloned().chain([IOxMigration { + version: 2, + description: "".into(), + steps: [IOxMigrationStep::SqlStatement { + // NOT idempotent! + sql: "CREATE TABLE s (x INT);".into(), + in_transaction: false, + }] + .into(), + checksum: [].into(), + other_compatible_checksums: [].into(), + }]), + ) + .unwrap(); + + let applied = migrator.run_direct(conn).await.unwrap(); + assert_eq!(applied, HashSet::from([2])); + + let applied = migrator.run_direct(conn).await.unwrap(); + assert_eq!(applied, HashSet::from([])); + } + + #[tokio::test] + async fn test_migrator_fail_clean_migration_missing() { + maybe_skip_integration!(); + let mut conn = setup().await; + let conn = &mut *conn; + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: [].into(), + checksum: [].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + + migrator.run_direct(conn).await.unwrap(); + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 2, + description: "".into(), + steps: [].into(), + checksum: [].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + + let err = migrator.run_direct(conn).await.unwrap_err(); + assert_eq!( + err.to_string(), + "migration 1 was previously applied but is missing in the resolved migrations" + ); + } + + #[tokio::test] + async fn test_migrator_fail_dirty_migration_missing() { + maybe_skip_integration!(); + let mut conn = setup().await; + let conn = &mut *conn; + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: [IOxMigrationStep::SqlStatement { + sql: "foo".into(), + in_transaction: 
false, + }] + .into(), + checksum: [].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + + migrator.run_direct(conn).await.unwrap_err(); + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 2, + description: "".into(), + steps: [].into(), + checksum: [].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + + let err = migrator.run_direct(conn).await.unwrap_err(); + assert_eq!( + err.to_string(), + "migration 1 was previously applied but is missing in the resolved migrations" + ); + } + + #[tokio::test] + async fn test_migrator_fail_clean_checksum_mismatch() { + maybe_skip_integration!(); + let mut conn = setup().await; + let conn = &mut *conn; + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: [].into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + + migrator.run_direct(conn).await.unwrap(); + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: [].into(), + checksum: [4, 5, 6].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + + let err = migrator.run_direct(conn).await.unwrap_err(); + assert_eq!( + err.to_string(), + "migration 1 was previously applied but has been modified" + ); + } + + #[tokio::test] + async fn test_migrator_fail_dirty_checksum_mismatch() { + maybe_skip_integration!(); + let mut conn = setup().await; + let conn = &mut *conn; + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: [IOxMigrationStep::SqlStatement { + sql: "foo".into(), + in_transaction: false, + }] + .into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + + migrator.run_direct(conn).await.unwrap_err(); + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: [IOxMigrationStep::SqlStatement { + sql: "foo".into(), + 
in_transaction: false, + }] + .into(), + checksum: [4, 5, 6].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + + let err = migrator.run_direct(conn).await.unwrap_err(); + assert_eq!( + err.to_string(), + "migration 1 was previously applied but has been modified" + ); + } + + #[tokio::test] + async fn test_migrator_other_compatible_checksum() { + maybe_skip_integration!(); + let mut conn = setup().await; + let conn = &mut *conn; + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: [].into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + + migrator.run_direct(conn).await.unwrap(); + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: [].into(), + checksum: [4, 5, 6].into(), + other_compatible_checksums: [[1, 2, 3].into()].into(), + }]) + .unwrap(); + + migrator.run_direct(conn).await.unwrap(); + } + + /// Migrations may have the same checksum. + /// + /// This is helpful if you want to revert a change later, e.g.: + /// + /// 1. add a index + /// 2. remove the index + /// 3. 
decide that you actually need the index again + #[tokio::test] + async fn test_migrator_migrations_can_have_same_checksum() { + maybe_skip_integration!(); + let mut conn = setup().await; + let conn = &mut *conn; + + let migrator = IOxMigrator::try_new([ + IOxMigration { + version: 1, + description: "".into(), + steps: [].into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }, + IOxMigration { + version: 2, + description: "".into(), + steps: [].into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }, + ]) + .unwrap(); + + let applied = migrator.run_direct(conn).await.unwrap(); + assert_eq!(applied, HashSet::from([1, 2])); + } + + #[tokio::test] + async fn test_migrator_recover_dirty_same() { + test_migrator_recover_dirty_inner(RecoverFromDirtyMode::Same).await; + } + + #[tokio::test] + async fn test_migrator_recover_dirty_fix_non_transactional() { + test_migrator_recover_dirty_inner(RecoverFromDirtyMode::FixNonTransactional).await; + } + + #[tokio::test] + async fn test_migrator_recover_dirty_fix_transactional() { + test_migrator_recover_dirty_inner(RecoverFromDirtyMode::FixTransactional).await; + } + + /// Modes for [`test_migrator_recover_dirty_inner`] + #[derive(Debug)] + enum RecoverFromDirtyMode { + /// Recover from a fluke. + /// + /// The checksum of the migration stays the same and it is non-transactional (otherwise we wouldn't have + /// ended up in a dirty state to begin with). + Same, + + /// Recover using a fixed version, the fix is still non-transactional. + FixNonTransactional, + + /// Recover using a fixed version, the fix is transactional (in contrast to the original version). 
+ FixTransactional, + } + + impl RecoverFromDirtyMode { + fn same_checksum(&self) -> bool { + match self { + Self::Same => true, + Self::FixNonTransactional => false, + Self::FixTransactional => false, + } + } + + fn fix_is_transactional(&self) -> bool { + match self { + Self::Same => false, + Self::FixNonTransactional => false, + Self::FixTransactional => true, + } + } + } + + async fn test_migrator_recover_dirty_inner(mode: RecoverFromDirtyMode) { + maybe_skip_integration!(); + let mut conn = setup().await; + let conn = &mut *conn; + + conn.execute("CREATE TABLE t (x INT);").await.unwrap(); + let test_query = "SELECT COALESCE(SUM(x), 0)::INT AS r FROM t;"; + + let steps_ok = vec![ + IOxMigrationStep::SqlStatement { + sql: "INSERT INTO t VALUES (1);".into(), + // set to NO transaction, otherwise the migrator will happily wrap the migration bookkeeping and the + // migration script itself into a single transaction to avoid the "dirty" state + in_transaction: mode.fix_is_transactional(), + }, + IOxMigrationStep::SqlStatement { + sql: "INSERT INTO t VALUES (2);".into(), + in_transaction: mode.fix_is_transactional(), + }, + ]; + + let mut steps_broken = steps_ok.clone(); + steps_broken[0] = IOxMigrationStep::SqlStatement { + sql: "foo".into(), + // set to NO transaction, otherwise the migrator will happily wrap the migration bookkeeping and the + // migration script itself into a single transaction to avoid the "dirty" state + in_transaction: false, + }; + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: steps_broken.into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + + migrator.run_direct(conn).await.unwrap_err(); + + let r: i32 = query_scalar(test_query) + .fetch_one(&mut *conn) + .await + .unwrap(); + assert_eq!(r, 0); + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: steps_ok.into(), + checksum: if 
mode.same_checksum() { + [1, 2, 3].into() + } else { + [4, 5, 6].into() + }, + other_compatible_checksums: if mode.same_checksum() { + [].into() + } else { + [[1, 2, 3].into()].into() + }, + }]) + .unwrap(); + + let applied = migrator.run_direct(conn).await.unwrap(); + assert_eq!(applied, HashSet::from([1])); + + let r: i32 = query_scalar(test_query) + .fetch_one(&mut *conn) + .await + .unwrap(); + assert_eq!(r, 3); + } + + #[tokio::test] + async fn test_migrator_uses_single_transaction_when_possible() { + maybe_skip_integration!(); + let mut conn = setup().await; + let conn = &mut *conn; + + conn.execute("CREATE TABLE t (x INT);").await.unwrap(); + + let steps_ok = vec![ + IOxMigrationStep::SqlStatement { + sql: "INSERT INTO t VALUES (1);".into(), + in_transaction: true, + }, + IOxMigrationStep::SqlStatement { + sql: "INSERT INTO t VALUES (2);".into(), + in_transaction: true, + }, + IOxMigrationStep::SqlStatement { + sql: "INSERT INTO t VALUES (3);".into(), + in_transaction: true, + }, + ]; + + // break in-between step that is sandwiched by two valid ones + let mut steps_broken = steps_ok.clone(); + steps_broken[1] = IOxMigrationStep::SqlStatement { + sql: "foo".into(), + in_transaction: true, + }; + + let test_query = "SELECT COALESCE(SUM(x), 0)::INT AS r FROM t;"; + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: steps_broken.into(), + // use a placeholder checksum (normally this would be calculated based on the steps) + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + migrator.run_direct(conn).await.unwrap_err(); + + // all or nothing: nothing + let r: i32 = query_scalar(test_query) + .fetch_one(&mut *conn) + .await + .unwrap(); + assert_eq!(r, 0); + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: steps_ok.into(), + // same checksum, but now w/ valid steps (to simulate a once failed SQL statement) + checksum: [1, 
2, 3].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + + let applied = migrator.run_direct(conn).await.unwrap(); + assert_eq!(applied, HashSet::from([1]),); + + // all or nothing: all + let r: i32 = query_scalar(test_query).fetch_one(conn).await.unwrap(); + assert_eq!(r, 6); + } + + /// Tests that `CREATE INDEX CONCURRENTLY` doesn't deadlock. + /// + /// Originally we used SQLx to acquire the locks which uses `pg_advisory_lock`. However this seems to acquire a + /// global "shared lock". Other migration frameworks faced the same issue and use `pg_try_advisory_lock` + /// instead. Also see: + /// + /// - + /// - + #[tokio::test] + async fn test_locking() { + const N_TABLES_AND_INDICES: usize = 10; + const N_CONCURRENT_MIGRATIONS: usize = 100; + + maybe_skip_integration!(); + maybe_start_logging(); + let pool = setup_pool().await; + + let migrator = Arc::new( + IOxMigrator::try_new((0..N_TABLES_AND_INDICES).map(|i| { + IOxMigration { + version: i as i64, + description: "".into(), + steps: [ + IOxMigrationStep::SqlStatement { + sql: format!("CREATE TABLE t{i} (x INT);").into(), + in_transaction: false, + }, + IOxMigrationStep::SqlStatement { + sql: format!("CREATE INDEX CONCURRENTLY i{i} ON t{i} (x);").into(), + in_transaction: false, + }, + ] + .into(), + checksum: [].into(), + other_compatible_checksums: [].into(), + } + })) + .unwrap(), + ); + + let mut futures: FuturesUnordered<_> = (0..N_CONCURRENT_MIGRATIONS) + .map(move |_| { + let migrator = Arc::clone(&migrator); + let pool = pool.clone(); + async move { + // pool might timeout, so add another retry loop around it + let mut conn = loop { + let pool = pool.clone(); + if let Ok(conn) = pool.acquire().await { + break conn; + } + }; + let conn = &mut *conn; + migrator.run_direct(conn).await.unwrap(); + } + }) + .collect(); + while futures.next().await.is_some() {} + } + + /// This tests that: + /// + /// - indexes are sanity-checked + /// - sanity checks are applied after each new/dirty 
migration and we keep the migration dirty until the checks + /// pass + /// - we can manually recover the database and make the non-idempotent migration pass + #[tokio::test] + async fn test_sanity_checks_index_1() { + maybe_skip_integration!(); + let mut conn = setup().await; + let conn = &mut *conn; + + conn.execute("CREATE TABLE t (x INT, y INT);") + .await + .unwrap(); + conn.execute("INSERT INTO t VALUES (1, 1);").await.unwrap(); + conn.execute("INSERT INTO t VALUES (1, 2);").await.unwrap(); + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: [IOxMigrationStep::SqlStatement { + sql: "CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS i ON t (x);".into(), + in_transaction: false, + }] + .into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + + // fails because is not unique + let err = migrator.run_direct(conn).await.unwrap_err(); + assert_eq!( + err.to_string(), + "while executing migrations: error returned from database: could not create unique index \"i\"" + ); + + // re-applying fails due to sanity checks + // NOTE: Even though the actual migration script passes, the sanity checks DO NOT and hence the migration is + // still considered dirty. It will be re-applied after the manual intervention below. 
+ let err = migrator.run_direct(conn).await.unwrap_err(); + assert_eq!( + err.to_string(), + "while resolving migrations: Found invalid indexes: i" + ); + + // fix data and wipe index + conn.execute("DELETE FROM t WHERE y = 2;").await.unwrap(); + conn.execute("DROP INDEX i;").await.unwrap(); + + // applying works + let applied = migrator.run_direct(conn).await.unwrap(); + assert_eq!(HashSet::from([1]), applied); + } + + /// This tests that: + /// + /// - indexes are sanity-checked + /// - we can fix a data error and a proper, idempotent migration will eventually pass + #[tokio::test] + async fn test_sanity_checks_index_2() { + maybe_skip_integration!(); + let mut conn = setup().await; + let conn = &mut *conn; + + conn.execute("CREATE TABLE t (x INT, y INT);") + .await + .unwrap(); + conn.execute("INSERT INTO t VALUES (1, 1);").await.unwrap(); + conn.execute("INSERT INTO t VALUES (1, 2);").await.unwrap(); + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: [ + IOxMigrationStep::SqlStatement { + sql: "DROP INDEX IF EXISTS i;".into(), + in_transaction: false, + }, + IOxMigrationStep::SqlStatement { + sql: "CREATE UNIQUE INDEX CONCURRENTLY i ON t (x);".into(), + in_transaction: false, + }, + ] + .into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + + // fails because is not unique + let err = migrator.run_direct(conn).await.unwrap_err(); + assert_eq!( + err.to_string(), + "while executing migrations: error returned from database: could not create unique index \"i\"" + ); + + // re-applying fails with same error (index is wiped but fails w/ same error) + let err = migrator.run_direct(conn).await.unwrap_err(); + assert_eq!( + err.to_string(), + "while executing migrations: error returned from database: could not create unique index \"i\"" + ); + + // fix data issue + conn.execute("UPDATE t SET x = 2 WHERE y = 2") + .await + .unwrap(); + + // now it works + let applied = 
migrator.run_direct(conn).await.unwrap(); + assert_eq!(HashSet::from([1]), applied); + } + + #[tokio::test] + async fn test_migrator_fail_new_migration_before_applied() { + maybe_skip_integration!(); + let mut conn = setup().await; + let conn = &mut *conn; + + let migration_1 = IOxMigration { + version: 1, + description: "".into(), + steps: [].into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }; + let migration_2 = IOxMigration { + version: 2, + description: "".into(), + steps: [].into(), + checksum: [4, 5, 6].into(), + other_compatible_checksums: [].into(), + }; + + let migrator = IOxMigrator::try_new([migration_2.clone()]).unwrap(); + + let applied = migrator.run_direct(conn).await.unwrap(); + assert_eq!(HashSet::from([2]), applied); + + let migrator = IOxMigrator::try_new([migration_1, migration_2]).unwrap(); + + let err = migrator.run_direct(conn).await.unwrap_err(); + + assert_eq!( + err.to_string(), + "while resolving migrations: new migration (1) goes before last applied migration (2), \ + this should not have been merged!", + ); + } + + #[tokio::test] + async fn test_migrator_fail_new_migration_before_dirty() { + maybe_skip_integration!(); + let mut conn = setup().await; + let conn = &mut *conn; + + let migration_1 = IOxMigration { + version: 1, + description: "".into(), + steps: [].into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }; + let migration_2 = IOxMigration { + version: 2, + description: "".into(), + steps: [IOxMigrationStep::SqlStatement { + sql: "foo".into(), + in_transaction: false, + }] + .into(), + checksum: [4, 5, 6].into(), + other_compatible_checksums: [].into(), + }; + + let migrator = IOxMigrator::try_new([migration_2.clone()]).unwrap(); + + migrator.run_direct(conn).await.unwrap_err(); + + let migrator = IOxMigrator::try_new([migration_1, migration_2]).unwrap(); + + let err = migrator.run_direct(conn).await.unwrap_err(); + + assert_eq!( + err.to_string(), + "while resolving 
migrations: new migration (1) goes before dirty version (2), \ + this should not have been merged!", + ); + } + + #[tokio::test] + async fn test_migrator_bug_selftest_multiple_dirty_migrations() { + maybe_skip_integration!(); + let mut conn = setup().await; + let conn = &mut *conn; + + let migrator = IOxMigrator::try_new([ + IOxMigration { + version: 1, + description: "".into(), + steps: [].into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }, + IOxMigration { + version: 2, + description: "".into(), + steps: [].into(), + checksum: [4, 5, 6].into(), + other_compatible_checksums: [].into(), + }, + ]) + .unwrap(); + + migrator.run_direct(conn).await.unwrap(); + + conn.execute("UPDATE _sqlx_migrations SET success = FALSE;") + .await + .unwrap(); + + let err = migrator.run_direct(conn).await.unwrap_err(); + + assert_eq!( + err.to_string(), + "while resolving migrations: there are multiple dirty versions, \ + this should not happen and is considered a bug: [1, 2]", + ); + } + + #[tokio::test] + async fn test_migrator_bug_selftest_applied_after_dirty() { + maybe_skip_integration!(); + let mut conn = setup().await; + let conn = &mut *conn; + + let migrator = IOxMigrator::try_new([ + IOxMigration { + version: 1, + description: "".into(), + steps: [].into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }, + IOxMigration { + version: 2, + description: "".into(), + steps: [].into(), + checksum: [4, 5, 6].into(), + other_compatible_checksums: [].into(), + }, + ]) + .unwrap(); + + migrator.run_direct(conn).await.unwrap(); + + conn.execute("UPDATE _sqlx_migrations SET success = FALSE WHERE version = 1;") + .await + .unwrap(); + + let err = migrator.run_direct(conn).await.unwrap_err(); + + assert_eq!( + err.to_string(), + "while resolving migrations: dirty version (1) is not the last applied version (2), this is a bug", + ); + } + + #[tokio::test] + async fn test_migrator_allows_unknown_migrations_if_they_are_clean() { + 
maybe_skip_integration!(); + let mut conn = setup().await; + let conn = &mut *conn; + + let migrator_1 = IOxMigrator::try_new([ + IOxMigration { + version: 1, + description: "".into(), + steps: [].into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }, + IOxMigration { + version: 2, + description: "".into(), + steps: [].into(), + checksum: [4, 5, 6].into(), + other_compatible_checksums: [].into(), + }, + ]) + .unwrap(); + let migrator_2 = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: [].into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + + migrator_1.run_direct(conn).await.unwrap(); + migrator_2.run_direct(conn).await.unwrap(); + } + + #[tokio::test] + async fn test_tester_finds_invalid_migration() { + maybe_skip_integration!(); + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: [IOxMigrationStep::SqlStatement { + sql: "foo".into(), + in_transaction: true, + }] + .into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + + let err = test_utils::test_migration(&migrator, setup_pool) + .await + .unwrap_err(); + + assert_eq!( + err.to_string(), + "while executing migrations: error returned from database: syntax error at or near \"foo\"", + ); + } + + #[tokio::test] + async fn test_tester_finds_non_idempotent_migration_package() { + maybe_skip_integration!(); + + let migrator = IOxMigrator::try_new([IOxMigration { + version: 1, + description: "".into(), + steps: [IOxMigrationStep::SqlStatement { + sql: "CREATE TABLE t (x INT);".into(), + // do NOT run this in a transaction, otherwise this is automatically idempotent + in_transaction: false, + }] + .into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }]) + .unwrap(); + + let err = test_utils::test_migration(&migrator, setup_pool) + .await + .unwrap_err(); + + assert_eq!( + 
err.to_string(), + "while executing migrations: error returned from database: relation \"t\" already exists", + ); + } + + #[tokio::test] + async fn test_tester_finds_non_idempotent_migration_step() { + maybe_skip_integration!(); + + let migrator = IOxMigrator::try_new([ + IOxMigration { + version: 1, + description: "".into(), + steps: [IOxMigrationStep::SqlStatement { + sql: "CREATE TABLE t (x INT);".into(), + in_transaction: true, + }] + .into(), + checksum: [1, 2, 3].into(), + other_compatible_checksums: [].into(), + }, + IOxMigration { + version: 2, + description: "".into(), + steps: [ + IOxMigrationStep::SqlStatement { + sql: "DROP TABLE t;".into(), + // do NOT run this in a transaction, otherwise this is automatically idempotent + in_transaction: false, + }, + IOxMigrationStep::SqlStatement { + sql: "CREATE TABLE t (x INT);".into(), + // do NOT run this in a transaction, otherwise this is automatically idempotent + in_transaction: false, + }, + ] + .into(), + checksum: [4, 5, 6].into(), + other_compatible_checksums: [].into(), + }, + ]) + .unwrap(); + + let err = test_utils::test_migration(&migrator, setup_pool) + .await + .unwrap_err(); + + assert_eq!( + err.to_string(), + "while executing migrations: error returned from database: table \"t\" does not exist", + ); + } + + async fn setup_pool() -> HotSwapPool { + maybe_start_logging(); + + setup_db_no_migration().await.into_pool() + } + + async fn setup() -> PoolConnection { + let pool = setup_pool().await; + pool.acquire().await.unwrap() + } + + #[track_caller] + fn assert_starts_with(s: &str, prefix: &str) { + if !s.starts_with(prefix) { + panic!("'{s}' does not start with '{prefix}'"); + } + } + } +} diff --git a/iox_catalog/src/postgres.rs b/iox_catalog/src/postgres.rs new file mode 100644 index 0000000..ef9c5d2 --- /dev/null +++ b/iox_catalog/src/postgres.rs @@ -0,0 +1,2783 @@ +//! 
A Postgres backed implementation of the Catalog + +use crate::interface::PartitionRepoExt; +use crate::{ + constants::{ + MAX_PARQUET_FILES_SELECTED_ONCE_FOR_DELETE, MAX_PARQUET_FILES_SELECTED_ONCE_FOR_RETENTION, + }, + interface::{ + AlreadyExistsSnafu, CasFailure, Catalog, ColumnRepo, Error, NamespaceRepo, ParquetFileRepo, + PartitionRepo, RepoCollection, Result, SoftDeletedRows, TableRepo, + }, + metrics::MetricDecorator, + migrate::IOxMigrator, +}; +use async_trait::async_trait; +use data_types::snapshot::partition::PartitionSnapshot; +use data_types::snapshot::table::TableSnapshot; +use data_types::{ + partition_template::{ + NamespacePartitionTemplateOverride, TablePartitionTemplateOverride, TemplatePart, + }, + Column, ColumnType, CompactionLevel, MaxColumnsPerTable, MaxTables, Namespace, NamespaceId, + NamespaceName, NamespaceServiceProtectionLimitsOverride, ObjectStoreId, ParquetFile, + ParquetFileId, ParquetFileParams, Partition, PartitionHashId, PartitionId, PartitionKey, + SkippedCompaction, SortKeyIds, Table, TableId, Timestamp, +}; +use iox_time::{SystemProvider, TimeProvider}; +use metric::{Attributes, Instrument, MetricKind}; +use observability_deps::tracing::{debug, info, warn}; +use once_cell::sync::Lazy; +use parking_lot::{RwLock, RwLockWriteGuard}; +use snafu::prelude::*; +use sqlx::{ + postgres::{PgConnectOptions, PgPoolOptions}, + Acquire, ConnectOptions, Executor, Postgres, Row, +}; +use sqlx_hotswap_pool::HotSwapPool; +use std::{ + borrow::Cow, + collections::{HashMap, HashSet}, + env, + fmt::Display, + str::FromStr, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::Duration, +}; + +static MIGRATOR: Lazy = + Lazy::new(|| IOxMigrator::try_from(&sqlx::migrate!()).expect("valid migration")); + +/// Postgres connection options. +#[derive(Debug, Clone)] +pub struct PostgresConnectionOptions { + /// Application name. + /// + /// This will be reported to postgres. + pub app_name: String, + + /// Schema name. 
+ pub schema_name: String, + + /// DSN. + pub dsn: String, + + /// Maximum number of concurrent connections. + pub max_conns: u32, + + /// Set the amount of time to attempt connecting to the database. + pub connect_timeout: Duration, + + /// Set a maximum idle duration for individual connections. + pub idle_timeout: Duration, + + /// If the DSN points to a file (i.e. starts with `dsn-file://`), this sets the interval how often the the file + /// should be polled for updates. + /// + /// If an update is encountered, the underlying connection pool will be hot-swapped. + pub hotswap_poll_interval: Duration, +} + +impl PostgresConnectionOptions { + /// Default value for [`schema_name`](Self::schema_name). + pub const DEFAULT_SCHEMA_NAME: &'static str = "iox_catalog"; + + /// Default value for [`max_conns`](Self::max_conns). + pub const DEFAULT_MAX_CONNS: u32 = 10; + + /// Default value for [`connect_timeout`](Self::connect_timeout). + pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(2); + + /// Default value for [`idle_timeout`](Self::idle_timeout). + pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(10); + + /// Default value for [`hotswap_poll_interval`](Self::hotswap_poll_interval). + pub const DEFAULT_HOTSWAP_POLL_INTERVAL: Duration = Duration::from_secs(5); +} + +impl Default for PostgresConnectionOptions { + fn default() -> Self { + Self { + app_name: String::from("iox"), + schema_name: String::from(Self::DEFAULT_SCHEMA_NAME), + dsn: String::new(), + max_conns: Self::DEFAULT_MAX_CONNS, + connect_timeout: Self::DEFAULT_CONNECT_TIMEOUT, + idle_timeout: Self::DEFAULT_IDLE_TIMEOUT, + hotswap_poll_interval: Self::DEFAULT_HOTSWAP_POLL_INTERVAL, + } + } +} + +/// PostgreSQL catalog. +#[derive(Debug)] +pub struct PostgresCatalog { + metrics: Arc, + pool: HotSwapPool, + time_provider: Arc, + // Connection options for display + options: PostgresConnectionOptions, +} + +impl PostgresCatalog { + /// Connect to the catalog store. 
+ pub async fn connect( + options: PostgresConnectionOptions, + metrics: Arc, + ) -> Result { + let pool = new_pool(&options, Arc::clone(&metrics)).await?; + + Ok(Self { + pool, + metrics, + time_provider: Arc::new(SystemProvider::new()), + options, + }) + } + + fn schema_name(&self) -> &str { + &self.options.schema_name + } + + #[cfg(test)] + pub(crate) fn into_pool(self) -> HotSwapPool { + self.pool + } +} + +impl Display for PostgresCatalog { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + // Do not include dsn in log as it may have credentials + // that should not end up in the log + "Postgres(dsn=OMITTED, schema_name='{}')", + self.schema_name() + ) + } +} + +/// transaction for [`PostgresCatalog`]. +#[derive(Debug)] +pub struct PostgresTxn { + inner: PostgresTxnInner, + time_provider: Arc, +} + +#[derive(Debug)] +struct PostgresTxnInner { + pool: HotSwapPool, +} + +impl<'c> Executor<'c> for &'c mut PostgresTxnInner { + type Database = Postgres; + + #[allow(clippy::type_complexity)] + fn fetch_many<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> futures::stream::BoxStream< + 'e, + Result< + sqlx::Either< + ::QueryResult, + ::Row, + >, + sqlx::Error, + >, + > + where + 'c: 'e, + E: sqlx::Execute<'q, Self::Database>, + { + self.pool.fetch_many(query) + } + + fn fetch_optional<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> futures::future::BoxFuture< + 'e, + Result::Row>, sqlx::Error>, + > + where + 'c: 'e, + E: sqlx::Execute<'q, Self::Database>, + { + self.pool.fetch_optional(query) + } + + fn prepare_with<'e, 'q: 'e>( + self, + sql: &'q str, + parameters: &'e [::TypeInfo], + ) -> futures::future::BoxFuture< + 'e, + Result<>::Statement, sqlx::Error>, + > + where + 'c: 'e, + { + self.pool.prepare_with(sql, parameters) + } + + fn describe<'e, 'q: 'e>( + self, + sql: &'q str, + ) -> futures::future::BoxFuture<'e, Result, sqlx::Error>> + where + 'c: 'e, + { + self.pool.describe(sql) + } +} + +#[async_trait] +impl Catalog for 
PostgresCatalog { + async fn setup(&self) -> Result<(), Error> { + // We need to create the schema if we're going to set it as the first item of the + // search_path otherwise when we run the sqlx migration scripts for the first time, sqlx + // will create the `_sqlx_migrations` table in the public namespace (the only namespace + // that exists), but the second time it will create it in the `` namespace and + // re-run all the migrations without skipping the ones already applied (see #3893). + // + // This makes the migrations/20210217134322_create_schema.sql step unnecessary; we need to + // keep that file because migration files are immutable. + let create_schema_query = format!("CREATE SCHEMA IF NOT EXISTS {};", self.schema_name()); + self.pool.execute(sqlx::query(&create_schema_query)).await?; + + MIGRATOR.run(&self.pool).await?; + + Ok(()) + } + + fn repositories(&self) -> Box { + Box::new(MetricDecorator::new( + PostgresTxn { + inner: PostgresTxnInner { + pool: self.pool.clone(), + }, + time_provider: Arc::clone(&self.time_provider), + }, + Arc::clone(&self.metrics), + Arc::clone(&self.time_provider), + )) + } + + #[cfg(test)] + fn metrics(&self) -> Arc { + Arc::clone(&self.metrics) + } + + fn time_provider(&self) -> Arc { + Arc::clone(&self.time_provider) + } +} + +/// Adapter to connect sqlx pools with our metrics system. +#[derive(Debug, Clone, Default)] +struct PoolMetrics { + /// Actual shared state. + state: Arc, +} + +/// Inner state of [`PoolMetrics`] that is wrapped into an [`Arc`]. +#[derive(Debug, Default)] +struct PoolMetricsInner { + /// Next pool ID. + pool_id_gen: AtomicU64, + + /// Set of known pools and their ID labels. + /// + /// Note: The pool is internally ref-counted via an [`Arc`]. Holding a reference does NOT prevent it from being closed. + pools: RwLock, sqlx::Pool)>>, +} + +impl PoolMetrics { + /// Create new pool metrics. 
+ fn new(metrics: Arc) -> Self { + metrics.register_instrument("iox_catalog_postgres", Self::default) + } + + /// Register a new pool. + fn register_pool(&self, pool: sqlx::Pool) { + let id = self + .state + .pool_id_gen + .fetch_add(1, Ordering::SeqCst) + .to_string() + .into(); + let mut pools = self.state.pools.write(); + pools.push((id, pool)); + } + + /// Remove closed pools from given list. + fn clean_pools(pools: &mut Vec<(Arc, sqlx::Pool)>) { + pools.retain(|(_id, p)| !p.is_closed()); + } +} + +impl Instrument for PoolMetrics { + fn report(&self, reporter: &mut dyn metric::Reporter) { + let mut pools = self.state.pools.write(); + Self::clean_pools(&mut pools); + let pools = RwLockWriteGuard::downgrade(pools); + + reporter.start_metric( + "sqlx_postgres_pools", + "Number of pools that sqlx uses", + MetricKind::U64Gauge, + ); + reporter.report_observation( + &Attributes::from([]), + metric::Observation::U64Gauge(pools.len() as u64), + ); + reporter.finish_metric(); + + reporter.start_metric( + "sqlx_postgres_connections", + "Number of connections within the postgres connection pool that sqlx uses", + MetricKind::U64Gauge, + ); + for (id, p) in pools.iter() { + let active = p.size() as u64; + let idle = p.num_idle() as u64; + + // We get both values independently (from underlying atomic counters) so they might be out of sync (with a + // low likelyhood). Calculating this value and emitting it is useful though since it allows easier use in + // dashboards since you can `max_over_time` w/o any recording rules. 
+ let used = active.saturating_sub(idle); + + reporter.report_observation( + &Attributes::from([ + ("pool_id", Cow::Owned(id.as_ref().to_owned())), + ("state", Cow::Borrowed("active")), + ]), + metric::Observation::U64Gauge(active), + ); + reporter.report_observation( + &Attributes::from([ + ("pool_id", Cow::Owned(id.as_ref().to_owned())), + ("state", Cow::Borrowed("idle")), + ]), + metric::Observation::U64Gauge(idle), + ); + reporter.report_observation( + &Attributes::from([ + ("pool_id", Cow::Owned(id.as_ref().to_owned())), + ("state", Cow::Borrowed("used")), + ]), + metric::Observation::U64Gauge(used), + ); + reporter.report_observation( + &Attributes::from([ + ("pool_id", Cow::Owned(id.as_ref().to_owned())), + ("state", Cow::Borrowed("max")), + ]), + metric::Observation::U64Gauge(p.options().get_max_connections() as u64), + ); + reporter.report_observation( + &Attributes::from([ + ("pool_id", Cow::Owned(id.as_ref().to_owned())), + ("state", Cow::Borrowed("min")), + ]), + metric::Observation::U64Gauge(p.options().get_min_connections() as u64), + ); + } + + reporter.finish_metric(); + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +/// Creates a new [`sqlx::Pool`] from a database config and an explicit DSN. +/// +/// This function doesn't support the IDPE specific `dsn-file://` uri scheme. +async fn new_raw_pool( + options: &PostgresConnectionOptions, + parsed_dsn: &str, + metrics: PoolMetrics, +) -> Result, sqlx::Error> { + // sqlx exposes some options as pool options, while other options are available as connection options. + let mut connect_options = PgConnectOptions::from_str(parsed_dsn)? + // the default is INFO, which is frankly surprising. + .log_statements(log::LevelFilter::Trace); + + // Workaround sqlx ignoring the SSL_CERT_FILE environment variable. + // Remove workaround when upstream sqlx handles SSL_CERT_FILE properly (#8994). 
+ let cert_file = env::var("SSL_CERT_FILE").unwrap_or_default(); + if !cert_file.is_empty() { + connect_options = connect_options.ssl_root_cert(cert_file); + } + + let app_name = options.app_name.clone(); + let app_name2 = options.app_name.clone(); // just to log below + let schema_name = options.schema_name.clone(); + let pool = PgPoolOptions::new() + .min_connections(1) + .max_connections(options.max_conns) + .acquire_timeout(options.connect_timeout) + .idle_timeout(options.idle_timeout) + .test_before_acquire(true) + .after_connect(move |c, _meta| { + let app_name = app_name.to_owned(); + let schema_name = schema_name.to_owned(); + Box::pin(async move { + // Tag the connection with the provided application name, while allowing it to + // be override from the connection string (aka DSN). + // If current_application_name is empty here it means the application name wasn't + // set as part of the DSN, and we can set it explicitly. + // Recall that this block is running on connection, not when creating the pool! + let current_application_name: String = + sqlx::query_scalar("SELECT current_setting('application_name');") + .fetch_one(&mut *c) + .await?; + if current_application_name.is_empty() { + sqlx::query("SELECT set_config('application_name', $1, false);") + .bind(&*app_name) + .execute(&mut *c) + .await?; + } + let search_path_query = format!("SET search_path TO {schema_name},public;"); + c.execute(sqlx::query(&search_path_query)).await?; + + // Ensure explicit timezone selection, instead of deferring to + // the server value. + c.execute("SET timezone = 'UTC';").await?; + Ok(()) + }) + }) + .connect_with(connect_options) + .await?; + + // Log a connection was successfully established and include the application + // name for cross-correlation between Conductor logs & database connections. 
+ info!(application_name=%app_name2, "connected to config store"); + + metrics.register_pool(pool.clone()); + Ok(pool) +} + +/// Parse a postgres catalog dsn, handling the special `dsn-file://` +/// syntax (see [`new_pool`] for more details). +/// +/// Returns an error if the dsn-file could not be read correctly. +pub fn parse_dsn(dsn: &str) -> Result { + let dsn = match get_dsn_file_path(dsn) { + Some(filename) => std::fs::read_to_string(filename)?, + None => dsn.to_string(), + }; + Ok(dsn) +} + +/// Creates a new HotSwapPool +/// +/// This function understands the IDPE specific `dsn-file://` dsn uri scheme +/// and hot swaps the pool with a new sqlx::Pool when the file changes. +/// This is useful because the credentials can be rotated by infrastructure +/// agents while the service is running. +/// +/// The file is polled for changes every `polling_interval`. +/// +/// The pool is replaced only once the new pool is successfully created. +/// The [`new_raw_pool`] function will return a new pool only if the connection +/// is successfull (see [`sqlx::pool::PoolOptions::test_before_acquire`]). +async fn new_pool( + options: &PostgresConnectionOptions, + metrics: Arc, +) -> Result, sqlx::Error> { + let parsed_dsn = parse_dsn(&options.dsn)?; + let metrics = PoolMetrics::new(metrics); + let pool = HotSwapPool::new(new_raw_pool(options, &parsed_dsn, metrics.clone()).await?); + let polling_interval = options.hotswap_poll_interval; + + if let Some(dsn_file) = get_dsn_file_path(&options.dsn) { + let pool = pool.clone(); + let options = options.clone(); + + // TODO(mkm): return a guard that stops this background worker. + // We create only one pool per process, but it would be cleaner to be + // able to properly destroy the pool. 
If we don't kill this worker we + // effectively keep the pool alive (since it holds a reference to the + // Pool) and we also potentially pollute the logs with spurious warnings + // if the dsn file disappears (this may be annoying if they show up in the test + // logs). + tokio::spawn(async move { + let mut current_dsn = parsed_dsn.clone(); + loop { + tokio::time::sleep(polling_interval).await; + + async fn try_update( + options: &PostgresConnectionOptions, + current_dsn: &str, + dsn_file: &str, + pool: &HotSwapPool, + metrics: PoolMetrics, + ) -> Result, sqlx::Error> { + let new_dsn = std::fs::read_to_string(dsn_file)?; + if new_dsn == current_dsn { + Ok(None) + } else { + let new_pool = new_raw_pool(options, &new_dsn, metrics).await?; + let old_pool = pool.replace(new_pool); + info!("replaced hotswap pool"); + info!(?old_pool, "closing old DB connection pool"); + // The pool is not closed on drop. We need to call `close`. + // It will close all idle connections, and wait until acquired connections + // are returned to the pool or closed. + old_pool.close().await; + info!(?old_pool, "closed old DB connection pool"); + Ok(Some(new_dsn)) + } + } + + match try_update(&options, ¤t_dsn, &dsn_file, &pool, metrics.clone()).await { + Ok(None) => {} + Ok(Some(new_dsn)) => { + current_dsn = new_dsn; + } + Err(e) => { + warn!( + error=%e, + filename=%dsn_file, + "not replacing hotswap pool because of an error \ + connecting to the new DSN" + ); + } + } + } + }); + } + + Ok(pool) +} + +// Parses a `dsn-file://` scheme, according to the rules of the IDPE kit/sql package. +// +// If the dsn matches the `dsn-file://` prefix, the prefix is removed and the rest is interpreted +// as a file name, in which case this function will return `Some(filename)`. +// Otherwise it will return None. No URI decoding is performed on the filename. 
+fn get_dsn_file_path(dsn: &str) -> Option { + const DSN_SCHEME: &str = "dsn-file://"; + dsn.starts_with(DSN_SCHEME) + .then(|| dsn[DSN_SCHEME.len()..].to_owned()) +} + +impl RepoCollection for PostgresTxn { + fn namespaces(&mut self) -> &mut dyn NamespaceRepo { + self + } + + fn tables(&mut self) -> &mut dyn TableRepo { + self + } + + fn columns(&mut self) -> &mut dyn ColumnRepo { + self + } + + fn partitions(&mut self) -> &mut dyn PartitionRepo { + self + } + + fn parquet_files(&mut self) -> &mut dyn ParquetFileRepo { + self + } +} + +async fn insert_column_with_connection<'q, E>( + executor: E, + name: &str, + table_id: TableId, + column_type: ColumnType, +) -> Result +where + E: Executor<'q, Database = Postgres>, +{ + let rec = sqlx::query_as::<_, Column>( + r#" +INSERT INTO column_name ( name, table_id, column_type ) +SELECT $1, table_id, $3 FROM ( + SELECT max_columns_per_table, namespace.id, table_name.id as table_id, COUNT(column_name.*) AS count + FROM namespace LEFT JOIN table_name ON namespace.id = table_name.namespace_id + LEFT JOIN column_name ON table_name.id = column_name.table_id + WHERE table_name.id = $2 + GROUP BY namespace.max_columns_per_table, namespace.id, table_name.id +) AS get_count WHERE count < max_columns_per_table +ON CONFLICT ON CONSTRAINT column_name_unique +DO UPDATE SET name = column_name.name +RETURNING *; + "#, + ) + .bind(name) // $1 + .bind(table_id) // $2 + .bind(column_type) // $3 + .fetch_one(executor) + .await + .map_err(|e| match e { + sqlx::Error::RowNotFound => Error::LimitExceeded { + descr: format!("couldn't create column {} in table {}; limit reached on namespace", name, table_id) + }, + _ => { + if is_fk_violation(&e) { + Error::NotFound { descr: e.to_string() } + } else { + Error::External { source: Box::new(e) } + } + }})?; + + ensure!( + rec.column_type == column_type, + AlreadyExistsSnafu { + descr: format!( + "column {} is type {} but schema update has type {}", + name, rec.column_type, column_type + ), + } + ); 
+ + Ok(rec) +} + +#[async_trait] +impl NamespaceRepo for PostgresTxn { + async fn create( + &mut self, + name: &NamespaceName<'_>, + partition_template: Option, + retention_period_ns: Option, + service_protection_limits: Option, + ) -> Result { + let max_tables = service_protection_limits + .and_then(|l| l.max_tables) + .unwrap_or_default(); + let max_columns_per_table = service_protection_limits + .and_then(|l| l.max_columns_per_table) + .unwrap_or_default(); + + let rec = sqlx::query_as::<_, Namespace>( + r#" +INSERT INTO namespace ( + name, retention_period_ns, max_tables, max_columns_per_table, partition_template +) +VALUES ( $1, $2, $3, $4, $5 ) +RETURNING id, name, retention_period_ns, max_tables, max_columns_per_table, deleted_at, + partition_template; + "#, + ) + .bind(name.as_str()) // $1 + .bind(retention_period_ns) // $2 + .bind(max_tables) // $3 + .bind(max_columns_per_table) // $4 + .bind(partition_template); // $5 + + let rec = rec.fetch_one(&mut self.inner).await.map_err(|e| { + if is_unique_violation(&e) { + Error::AlreadyExists { + descr: name.to_string(), + } + } else if is_fk_violation(&e) { + Error::NotFound { + descr: e.to_string(), + } + } else { + Error::External { + source: Box::new(e), + } + } + })?; + + Ok(rec) + } + + async fn list(&mut self, deleted: SoftDeletedRows) -> Result> { + let rec = sqlx::query_as::<_, Namespace>( + format!( + r#" +SELECT id, name, retention_period_ns, max_tables, max_columns_per_table, deleted_at, + partition_template +FROM namespace +WHERE {v}; + "#, + v = deleted.as_sql_predicate() + ) + .as_str(), + ) + .fetch_all(&mut self.inner) + .await?; + + Ok(rec) + } + + async fn get_by_id( + &mut self, + id: NamespaceId, + deleted: SoftDeletedRows, + ) -> Result> { + let rec = sqlx::query_as::<_, Namespace>( + format!( + r#" +SELECT id, name, retention_period_ns, max_tables, max_columns_per_table, deleted_at, + partition_template +FROM namespace +WHERE id=$1 AND {v}; + "#, + v = deleted.as_sql_predicate() + ) + 
.as_str(), + ) + .bind(id) // $1 + .fetch_one(&mut self.inner) + .await; + + if let Err(sqlx::Error::RowNotFound) = rec { + return Ok(None); + } + + let namespace = rec?; + + Ok(Some(namespace)) + } + + async fn get_by_name( + &mut self, + name: &str, + deleted: SoftDeletedRows, + ) -> Result> { + let rec = sqlx::query_as::<_, Namespace>( + format!( + r#" +SELECT id, name, retention_period_ns, max_tables, max_columns_per_table, deleted_at, + partition_template +FROM namespace +WHERE name=$1 AND {v}; + "#, + v = deleted.as_sql_predicate() + ) + .as_str(), + ) + .bind(name) // $1 + .fetch_one(&mut self.inner) + .await; + + if let Err(sqlx::Error::RowNotFound) = rec { + return Ok(None); + } + + let namespace = rec?; + + Ok(Some(namespace)) + } + + async fn soft_delete(&mut self, name: &str) -> Result<()> { + let flagged_at = Timestamp::from(self.time_provider.now()); + + // note that there is a uniqueness constraint on the name column in the DB + sqlx::query(r#"UPDATE namespace SET deleted_at=$1 WHERE name = $2;"#) + .bind(flagged_at) // $1 + .bind(name) // $2 + .execute(&mut self.inner) + .await + .map_err(Error::from) + .map(|_| ()) + } + + async fn update_table_limit(&mut self, name: &str, new_max: MaxTables) -> Result { + let rec = sqlx::query_as::<_, Namespace>( + r#" +UPDATE namespace +SET max_tables = $1 +WHERE name = $2 +RETURNING id, name, retention_period_ns, max_tables, max_columns_per_table, deleted_at, + partition_template; + "#, + ) + .bind(new_max) + .bind(name) + .fetch_one(&mut self.inner) + .await; + + let namespace = rec.map_err(|e| match e { + sqlx::Error::RowNotFound => Error::NotFound { + descr: name.to_string(), + }, + _ => Error::External { + source: Box::new(e), + }, + })?; + + Ok(namespace) + } + + async fn update_column_limit( + &mut self, + name: &str, + new_max: MaxColumnsPerTable, + ) -> Result { + let rec = sqlx::query_as::<_, Namespace>( + r#" +UPDATE namespace +SET max_columns_per_table = $1 +WHERE name = $2 +RETURNING id, name, 
retention_period_ns, max_tables, max_columns_per_table, deleted_at, + partition_template; + "#, + ) + .bind(new_max) + .bind(name) + .fetch_one(&mut self.inner) + .await; + + let namespace = rec.map_err(|e| match e { + sqlx::Error::RowNotFound => Error::NotFound { + descr: name.to_string(), + }, + _ => Error::External { + source: Box::new(e), + }, + })?; + + Ok(namespace) + } + + async fn update_retention_period( + &mut self, + name: &str, + retention_period_ns: Option, + ) -> Result { + let rec = sqlx::query_as::<_, Namespace>( + r#" +UPDATE namespace +SET retention_period_ns = $1 +WHERE name = $2 +RETURNING id, name, retention_period_ns, max_tables, max_columns_per_table, deleted_at, + partition_template; + "#, + ) + .bind(retention_period_ns) // $1 + .bind(name) // $2 + .fetch_one(&mut self.inner) + .await; + + let namespace = rec.map_err(|e| match e { + sqlx::Error::RowNotFound => Error::NotFound { + descr: name.to_string(), + }, + _ => Error::External { + source: Box::new(e), + }, + })?; + + Ok(namespace) + } +} + +#[async_trait] +impl TableRepo for PostgresTxn { + async fn create( + &mut self, + name: &str, + partition_template: TablePartitionTemplateOverride, + namespace_id: NamespaceId, + ) -> Result
{ + let mut tx = self.inner.pool.begin().await?; + + // A simple insert statement becomes quite complicated in order to avoid checking the table + // limits in a select and then conditionally inserting (which would be racey). + // + // from https://www.postgresql.org/docs/current/sql-insert.html + // "INSERT inserts new rows into a table. One can insert one or more rows specified by + // value expressions, or zero or more rows resulting from a query." + // By using SELECT rather than VALUES it will insert zero rows if it finds a null in the + // subquery, i.e. if count >= max_tables. fetch_one() will return a RowNotFound error if + // nothing was inserted. Not pretty! + let table = sqlx::query_as::<_, Table>( + r#" +INSERT INTO table_name ( name, namespace_id, partition_template ) +SELECT $1, id, $2 FROM ( + SELECT namespace.id AS id, max_tables, COUNT(table_name.*) AS count + FROM namespace LEFT JOIN table_name ON namespace.id = table_name.namespace_id + WHERE namespace.id = $3 + GROUP BY namespace.max_tables, table_name.namespace_id, namespace.id +) AS get_count WHERE count < max_tables +RETURNING *; + "#, + ) + .bind(name) // $1 + .bind(partition_template) // $2 + .bind(namespace_id) // $3 + .fetch_one(&mut *tx) + .await + .map_err(|e| match e { + sqlx::Error::RowNotFound => Error::LimitExceeded { + descr: format!( + "couldn't create table {}; limit reached on namespace {}", + name, namespace_id + ), + }, + _ => { + if is_unique_violation(&e) { + Error::AlreadyExists { + descr: format!("table '{name}' in namespace {namespace_id}"), + } + } else if is_fk_violation(&e) { + Error::NotFound { + descr: e.to_string(), + } + } else { + Error::External { + source: Box::new(e), + } + } + } + })?; + + // Partitioning is only supported for tags, so create tag columns for all `TagValue` + // partition template parts. 
It's important this happens within the table creation + // transaction so that there isn't a possibility of a concurrent write creating these + // columns with an unsupported type. + for template_part in table.partition_template.parts() { + if let TemplatePart::TagValue(tag_name) = template_part { + insert_column_with_connection(&mut *tx, tag_name, table.id, ColumnType::Tag) + .await?; + } + } + + tx.commit().await?; + + Ok(table) + } + + async fn get_by_id(&mut self, table_id: TableId) -> Result> { + let rec = sqlx::query_as::<_, Table>( + r#" +SELECT * +FROM table_name +WHERE id = $1; + "#, + ) + .bind(table_id) // $1 + .fetch_one(&mut self.inner) + .await; + + if let Err(sqlx::Error::RowNotFound) = rec { + return Ok(None); + } + + let table = rec?; + + Ok(Some(table)) + } + + async fn get_by_namespace_and_name( + &mut self, + namespace_id: NamespaceId, + name: &str, + ) -> Result> { + let rec = sqlx::query_as::<_, Table>( + r#" +SELECT * +FROM table_name +WHERE namespace_id = $1 AND name = $2; + "#, + ) + .bind(namespace_id) // $1 + .bind(name) // $2 + .fetch_one(&mut self.inner) + .await; + + if let Err(sqlx::Error::RowNotFound) = rec { + return Ok(None); + } + + let table = rec?; + + Ok(Some(table)) + } + + async fn list_by_namespace_id(&mut self, namespace_id: NamespaceId) -> Result> { + let rec = sqlx::query_as::<_, Table>( + r#" +SELECT * +FROM table_name +WHERE namespace_id = $1; + "#, + ) + .bind(namespace_id) + .fetch_all(&mut self.inner) + .await?; + + Ok(rec) + } + + async fn list(&mut self) -> Result> { + let rec = sqlx::query_as::<_, Table>("SELECT * FROM table_name;") + .fetch_all(&mut self.inner) + .await?; + + Ok(rec) + } + + async fn snapshot(&mut self, table_id: TableId) -> Result { + let mut tx = self.inner.pool.begin().await?; + let rec = sqlx::query_as::<_, Table>("SELECT * from table_name WHERE id = $1 FOR UPDATE;") + .bind(table_id) // $1 + .fetch_one(&mut *tx) + .await; + + if let Err(sqlx::Error::RowNotFound) = rec { + return 
Err(Error::NotFound { + descr: format!("table: {table_id}"), + }); + } + let table = rec?; + + let columns = sqlx::query_as::<_, Column>("SELECT * from column_name where table_id = $1;") + .bind(table_id) // $1 + .fetch_all(&mut *tx) + .await?; + + let partitions = + sqlx::query_as::<_, Partition>(r#"SELECT * FROM partition WHERE table_id = $1;"#) + .bind(table_id) // $1 + .fetch_all(&mut *tx) + .await?; + + let (generation,): (i64,) = sqlx::query_as( + "UPDATE table_name SET generation = generation + 1 where id = $1 RETURNING generation;", + ) + .bind(table_id) // $1 + .fetch_one(&mut *tx) + .await?; + + tx.commit().await?; + + Ok(TableSnapshot::encode( + table, + partitions, + columns, + generation as _, + )?) + } +} + +#[async_trait] +impl ColumnRepo for PostgresTxn { + async fn create_or_get( + &mut self, + name: &str, + table_id: TableId, + column_type: ColumnType, + ) -> Result { + insert_column_with_connection(&mut self.inner, name, table_id, column_type).await + } + + async fn list_by_namespace_id(&mut self, namespace_id: NamespaceId) -> Result> { + let rec = sqlx::query_as::<_, Column>( + r#" +SELECT column_name.* FROM table_name +INNER JOIN column_name on column_name.table_id = table_name.id +WHERE table_name.namespace_id = $1; + "#, + ) + .bind(namespace_id) + .fetch_all(&mut self.inner) + .await?; + + Ok(rec) + } + + async fn list_by_table_id(&mut self, table_id: TableId) -> Result> { + let rec = sqlx::query_as::<_, Column>( + r#" +SELECT * FROM column_name +WHERE table_id = $1; + "#, + ) + .bind(table_id) + .fetch_all(&mut self.inner) + .await?; + + Ok(rec) + } + + async fn list(&mut self) -> Result> { + let rec = sqlx::query_as::<_, Column>("SELECT * FROM column_name;") + .fetch_all(&mut self.inner) + .await?; + + Ok(rec) + } + + async fn create_or_get_many_unchecked( + &mut self, + table_id: TableId, + columns: HashMap<&str, ColumnType>, + ) -> Result> { + let num_columns = columns.len(); + let (v_name, v_column_type): (Vec<&str>, Vec) = columns + 
.iter() + .map(|(&name, &column_type)| (name, column_type as i16)) + .unzip(); + + // The `ORDER BY` in this statement is important to avoid deadlocks during concurrent + // writes to the same IOx table that each add many new columns. See: + // + // - + // - + // - + let out = sqlx::query_as::<_, Column>( + r#" +INSERT INTO column_name ( name, table_id, column_type ) +SELECT name, $1, column_type +FROM UNNEST($2, $3) as a(name, column_type) +ORDER BY name +ON CONFLICT ON CONSTRAINT column_name_unique +DO UPDATE SET name = column_name.name +RETURNING *; + "#, + ) + .bind(table_id) // $1 + .bind(&v_name) // $2 + .bind(&v_column_type) // $3 + .fetch_all(&mut self.inner) + .await + .map_err(|e| { + if is_fk_violation(&e) { + Error::NotFound { + descr: e.to_string(), + } + } else { + Error::External { + source: Box::new(e), + } + } + })?; + + assert_eq!(num_columns, out.len()); + + for existing in &out { + let want = columns.get(existing.name.as_str()).unwrap(); + ensure!( + existing.column_type == *want, + AlreadyExistsSnafu { + descr: format!( + "column {} is type {} but schema update has type {}", + existing.name, existing.column_type, want + ), + } + ); + } + + Ok(out) + } +} + +#[async_trait] +impl PartitionRepo for PostgresTxn { + async fn create_or_get(&mut self, key: PartitionKey, table_id: TableId) -> Result { + let hash_id = PartitionHashId::new(table_id, &key); + + let v = sqlx::query_as::<_, Partition>( + r#" +INSERT INTO partition + (partition_key, table_id, hash_id, sort_key_ids) +VALUES + ( $1, $2, $3, '{}') +ON CONFLICT ON CONSTRAINT partition_key_unique +DO UPDATE SET partition_key = partition.partition_key +RETURNING id, hash_id, table_id, partition_key, sort_key_ids, new_file_at; + "#, + ) + .bind(&key) // $1 + .bind(table_id) // $2 + .bind(&hash_id) // $3 + .fetch_one(&mut self.inner) + .await + .map_err(|e| { + if is_fk_violation(&e) { + Error::NotFound { + descr: e.to_string(), + } + } else if is_unique_violation(&e) { + // Logging more information 
to diagnose a production issue maybe + warn!( + error=?e, + %table_id, + %key, + %hash_id, + "possible duplicate partition_hash_id?" + ); + Error::External { + source: Box::new(e), + } + } else { + Error::External { + source: Box::new(e), + } + } + })?; + + Ok(v) + } + + async fn get_by_id_batch(&mut self, partition_ids: &[PartitionId]) -> Result> { + let ids: Vec<_> = partition_ids.iter().map(|p| p.get()).collect(); + + sqlx::query_as::<_, Partition>( + r#" +SELECT id, hash_id, table_id, partition_key, sort_key_ids, new_file_at +FROM partition +WHERE id = ANY($1); + "#, + ) + .bind(&ids[..]) // $1 + .fetch_all(&mut self.inner) + .await + .map_err(Error::from) + } + + async fn list_by_table_id(&mut self, table_id: TableId) -> Result> { + sqlx::query_as::<_, Partition>( + r#" +SELECT id, hash_id, table_id, partition_key, sort_key_ids, new_file_at +FROM partition +WHERE table_id = $1; + "#, + ) + .bind(table_id) // $1 + .fetch_all(&mut self.inner) + .await + .map_err(Error::from) + } + + async fn list_ids(&mut self) -> Result> { + sqlx::query_as( + r#" + SELECT p.id as partition_id + FROM partition p + "#, + ) + .fetch_all(&mut self.inner) + .await + .map_err(Error::from) + } + + /// Update the sort key for `partition_id` if and only if `old_sort_key` + /// matches the current value in the database. + /// + /// This compare-and-swap operation is allowed to spuriously return + /// [`CasFailure::ValueMismatch`] for performance reasons (avoiding multiple + /// round trips to service a transaction in the happy path). + async fn cas_sort_key( + &mut self, + partition_id: PartitionId, + old_sort_key_ids: Option<&SortKeyIds>, + new_sort_key_ids: &SortKeyIds, + ) -> Result> { + let old_sort_key_ids = old_sort_key_ids + .map(std::ops::Deref::deref) + .unwrap_or_default(); + + // This `match` will go away when all partitions have hash IDs in the database. 
+ let query = sqlx::query_as::<_, Partition>( + r#" +UPDATE partition +SET sort_key_ids = $1 +WHERE id = $2 AND sort_key_ids = $3 +RETURNING id, hash_id, table_id, partition_key, sort_key_ids, new_file_at; + "#, + ) + .bind(new_sort_key_ids) // $1 + .bind(partition_id) // $2 + .bind(old_sort_key_ids); // $3; + + let res = query.fetch_one(&mut self.inner).await; + + let partition = match res { + Ok(v) => v, + Err(sqlx::Error::RowNotFound) => { + // This update may have failed either because: + // + // * A row with the specified ID did not exist at query time + // (but may exist now!) + // * The sort key does not match. + // + // To differentiate, we submit a get partition query, returning + // the actual sort key if successful. + // + // NOTE: this is racy, but documented - this might return "Sort + // key differs! Old key: " + let partition = (self as &mut dyn PartitionRepo) + .get_by_id(partition_id) + .await + .map_err(CasFailure::QueryError)? + .ok_or(CasFailure::QueryError(Error::NotFound { + descr: partition_id.to_string(), + }))?; + return Err(CasFailure::ValueMismatch( + partition.sort_key_ids().cloned().unwrap_or_default(), + )); + } + Err(e) => { + return Err(CasFailure::QueryError(Error::External { + source: Box::new(e), + })) + } + }; + + debug!( + ?partition_id, + ?new_sort_key_ids, + "partition sort key cas successful" + ); + + Ok(partition) + } + + async fn record_skipped_compaction( + &mut self, + partition_id: PartitionId, + reason: &str, + num_files: usize, + limit_num_files: usize, + limit_num_files_first_in_partition: usize, + estimated_bytes: u64, + limit_bytes: u64, + ) -> Result<()> { + sqlx::query( + r#" +INSERT INTO skipped_compactions + ( partition_id, reason, num_files, limit_num_files, limit_num_files_first_in_partition, estimated_bytes, limit_bytes, skipped_at ) +VALUES + ( $1, $2, $3, $4, $5, $6, $7, extract(epoch from NOW()) ) +ON CONFLICT ( partition_id ) +DO UPDATE +SET +reason = EXCLUDED.reason, +num_files = EXCLUDED.num_files, 
+limit_num_files = EXCLUDED.limit_num_files, +limit_num_files_first_in_partition = EXCLUDED.limit_num_files_first_in_partition, +estimated_bytes = EXCLUDED.estimated_bytes, +limit_bytes = EXCLUDED.limit_bytes, +skipped_at = EXCLUDED.skipped_at; + "#, + ) + .bind(partition_id) // $1 + .bind(reason) + .bind(num_files as i64) + .bind(limit_num_files as i64) + .bind(limit_num_files_first_in_partition as i64) + .bind(estimated_bytes as i64) + .bind(limit_bytes as i64) + .execute(&mut self.inner) + .await?; + Ok(()) + } + + async fn get_in_skipped_compactions( + &mut self, + partition_ids: &[PartitionId], + ) -> Result> { + let rec = sqlx::query_as::<_, SkippedCompaction>( + r#"SELECT * FROM skipped_compactions WHERE partition_id = ANY($1);"#, + ) + .bind(partition_ids) // $1 + .fetch_all(&mut self.inner) + .await; + + if let Err(sqlx::Error::RowNotFound) = rec { + return Ok(Vec::new()); + } + + let skipped_partition_records = rec?; + + Ok(skipped_partition_records) + } + + async fn list_skipped_compactions(&mut self) -> Result> { + sqlx::query_as::<_, SkippedCompaction>( + r#" +SELECT * FROM skipped_compactions + "#, + ) + .fetch_all(&mut self.inner) + .await + .map_err(Error::from) + } + + async fn delete_skipped_compactions( + &mut self, + partition_id: PartitionId, + ) -> Result> { + sqlx::query_as::<_, SkippedCompaction>( + r#" +DELETE FROM skipped_compactions +WHERE partition_id = $1 +RETURNING * + "#, + ) + .bind(partition_id) + .fetch_optional(&mut self.inner) + .await + .map_err(Error::from) + } + + async fn most_recent_n(&mut self, n: usize) -> Result> { + sqlx::query_as( + r#" +SELECT id, hash_id, table_id, partition_key, sort_key_ids, new_file_at +FROM partition +ORDER BY id DESC +LIMIT $1;"#, + ) + .bind(n as i64) // $1 + .fetch_all(&mut self.inner) + .await + .map_err(Error::from) + } + + async fn partitions_new_file_between( + &mut self, + minimum_time: Timestamp, + maximum_time: Option, + ) -> Result> { + let sql = format!( + r#" + SELECT p.id as 
partition_id + FROM partition p + WHERE p.new_file_at > $1 + {} + "#, + maximum_time + .map(|_| "AND p.new_file_at < $2") + .unwrap_or_default() + ); + + sqlx::query_as(&sql) + .bind(minimum_time) // $1 + .bind(maximum_time) // $2 + .fetch_all(&mut self.inner) + .await + .map_err(Error::from) + } + + async fn list_old_style(&mut self) -> Result> { + // Correctness: the main caller of this function, the partition bloom + // filter, relies on all partitions being made available to it. + // + // This function MUST return the full set of old partitions to the + // caller - do NOT apply a LIMIT to this query. + // + // The load this query saves vastly outsizes the load this query causes. + sqlx::query_as( + r#" +SELECT id, hash_id, table_id, partition_key, sort_key_ids, new_file_at +FROM partition +WHERE hash_id IS NULL +ORDER BY id DESC;"#, + ) + .fetch_all(&mut self.inner) + .await + .map_err(Error::from) + } + + async fn snapshot(&mut self, partition_id: PartitionId) -> Result { + let mut tx = self.inner.pool.begin().await?; + + let rec = + sqlx::query_as::<_, Partition>("SELECT * from partition WHERE id = $1 FOR UPDATE;") + .bind(partition_id) // $1 + .fetch_one(&mut *tx) + .await; + if let Err(sqlx::Error::RowNotFound) = rec { + return Err(Error::NotFound { + descr: format!("partition: {partition_id}"), + }); + } + let partition = rec?; + + let files = + sqlx::query_as::<_, ParquetFile>("SELECT * from parquet_file where partition_id = $1 AND parquet_file.to_delete IS NULL;") + .bind(partition_id) // $1 + .fetch_all(&mut *tx) + .await?; + + let sc = sqlx::query_as::<_, SkippedCompaction>( + r#"SELECT * FROM skipped_compactions WHERE partition_id = $1;"#, + ) + .bind(partition_id) // $1 + .fetch_optional(&mut *tx) + .await?; + + let (generation, namespace_id): (i64,NamespaceId) = sqlx::query_as( + "UPDATE partition SET generation = partition.generation + 1 from table_name where partition.id = $1 and table_name.id = partition.table_id RETURNING partition.generation, 
table_name.namespace_id;", + ) + .bind(partition_id) // $1 + .fetch_one(&mut *tx) + .await?; + + tx.commit().await?; + + Ok(PartitionSnapshot::encode( + namespace_id, + partition, + files, + sc, + generation as _, + )?) + } +} + +#[async_trait] +impl ParquetFileRepo for PostgresTxn { + async fn flag_for_delete_by_retention(&mut self) -> Result> { + let flagged_at = Timestamp::from(self.time_provider.now()); + // TODO - include check of table retention period once implemented + let flagged = sqlx::query( + r#" +WITH parquet_file_ids as ( + SELECT parquet_file.object_store_id + FROM namespace, parquet_file + WHERE namespace.retention_period_ns IS NOT NULL + AND parquet_file.to_delete IS NULL + AND parquet_file.max_time < $1 - namespace.retention_period_ns + AND namespace.id = parquet_file.namespace_id + LIMIT $2 +) +UPDATE parquet_file +SET to_delete = $1 +WHERE object_store_id IN (SELECT object_store_id FROM parquet_file_ids) +RETURNING partition_id, object_store_id; + "#, + ) + .bind(flagged_at) // $1 + .bind(MAX_PARQUET_FILES_SELECTED_ONCE_FOR_RETENTION) // $2 + .fetch_all(&mut self.inner) + .await?; + + let flagged = flagged + .into_iter() + .map(|row| (row.get("partition_id"), row.get("object_store_id"))) + .collect(); + Ok(flagged) + } + + async fn delete_old_ids_only(&mut self, older_than: Timestamp) -> Result> { + // see https://www.crunchydata.com/blog/simulating-update-or-delete-with-limit-in-postgres-ctes-to-the-rescue + let deleted = sqlx::query( + r#" +WITH parquet_file_ids as ( + SELECT object_store_id + FROM parquet_file + WHERE to_delete < $1 + LIMIT $2 +) +DELETE FROM parquet_file +WHERE object_store_id IN (SELECT object_store_id FROM parquet_file_ids) +RETURNING object_store_id; + "#, + ) + .bind(older_than) // $1 + .bind(MAX_PARQUET_FILES_SELECTED_ONCE_FOR_DELETE) // $2 + .fetch_all(&mut self.inner) + .await?; + + let deleted = deleted + .into_iter() + .map(|row| row.get("object_store_id")) + .collect(); + Ok(deleted) + } + + async fn 
list_by_partition_not_to_delete_batch( + &mut self, + partition_ids: Vec, + ) -> Result> { + sqlx::query_as::<_, ParquetFile>( + r#" +SELECT parquet_file.id, namespace_id, parquet_file.table_id, partition_id, partition_hash_id, + object_store_id, min_time, max_time, parquet_file.to_delete, file_size_bytes, row_count, + compaction_level, created_at, column_set, max_l0_created_at +FROM parquet_file +WHERE parquet_file.partition_id = ANY($1) + AND parquet_file.to_delete IS NULL; + "#, + ) + .bind(partition_ids) // $1 + .fetch_all(&mut self.inner) + .await + .map_err(Error::from) + } + + async fn get_by_object_store_id( + &mut self, + object_store_id: ObjectStoreId, + ) -> Result> { + let rec = sqlx::query_as::<_, ParquetFile>( + r#" +SELECT id, namespace_id, table_id, partition_id, partition_hash_id, object_store_id, min_time, + max_time, to_delete, file_size_bytes, row_count, compaction_level, created_at, column_set, + max_l0_created_at +FROM parquet_file +WHERE object_store_id = $1; + "#, + ) + .bind(object_store_id) // $1 + .fetch_one(&mut self.inner) + .await; + + if let Err(sqlx::Error::RowNotFound) = rec { + return Ok(None); + } + + let parquet_file = rec?; + + Ok(Some(parquet_file)) + } + + async fn exists_by_object_store_id_batch( + &mut self, + object_store_ids: Vec, + ) -> Result> { + sqlx::query( + // sqlx's readme suggests using PG's ANY operator instead of IN; see link below. 
+ // https://github.com/launchbadge/sqlx/blob/main/FAQ.md#how-can-i-do-a-select--where-foo-in--query + r#" +SELECT object_store_id +FROM parquet_file +WHERE object_store_id = ANY($1); + "#, + ) + .bind(object_store_ids) // $1 + .map(|pgr| pgr.get::("object_store_id")) + .fetch_all(&mut self.inner) + .await + .map_err(Error::from) + } + + async fn create_upgrade_delete( + &mut self, + partition_id: PartitionId, + delete: &[ObjectStoreId], + upgrade: &[ObjectStoreId], + create: &[ParquetFileParams], + target_level: CompactionLevel, + ) -> Result> { + let delete_set: HashSet<_> = delete.iter().map(|d| d.get_uuid()).collect(); + let upgrade_set: HashSet<_> = upgrade.iter().map(|u| u.get_uuid()).collect(); + + assert!( + delete_set.is_disjoint(&upgrade_set), + "attempted to upgrade a file scheduled for delete" + ); + + let mut tx = self.inner.pool.begin().await?; + + let marked_at = Timestamp::from(self.time_provider.now()); + flag_for_delete(&mut *tx, partition_id, delete, marked_at).await?; + + update_compaction_level(&mut *tx, partition_id, upgrade, target_level).await?; + + let mut ids = Vec::with_capacity(create.len()); + for file in create { + if file.partition_id != partition_id { + return Err(Error::External { + source: format!("Inconsistent ParquetFileParams, expected PartitionId({partition_id}) got PartitionId({})", file.partition_id).into(), + }); + } + let id = create_parquet_file(&mut *tx, partition_id, file).await?; + ids.push(id); + } + + tx.commit().await?; + + Ok(ids) + } +} + +// The following three functions are helpers to the create_upgrade_delete method. +// They are also used by the respective create/flag_for_delete/update_compaction_level methods. 
+async fn create_parquet_file<'q, E>( + executor: E, + partition_id: PartitionId, + parquet_file_params: &ParquetFileParams, +) -> Result +where + E: Executor<'q, Database = Postgres>, +{ + let ParquetFileParams { + namespace_id, + table_id, + partition_id: _, + partition_hash_id, + object_store_id, + min_time, + max_time, + file_size_bytes, + row_count, + compaction_level, + created_at, + column_set, + max_l0_created_at, + } = parquet_file_params; + + let query = sqlx::query_scalar::<_, ParquetFileId>( + r#" +INSERT INTO parquet_file ( + table_id, partition_id, partition_hash_id, object_store_id, + min_time, max_time, file_size_bytes, + row_count, compaction_level, created_at, namespace_id, column_set, max_l0_created_at ) +VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13 ) +RETURNING id; + "#, + ) + .bind(table_id) // $1 + .bind(partition_id) // $2 + .bind(partition_hash_id.as_ref()) // $3 + .bind(object_store_id) // $4 + .bind(min_time) // $5 + .bind(max_time) // $6 + .bind(file_size_bytes) // $7 + .bind(row_count) // $8 + .bind(compaction_level) // $9 + .bind(created_at) // $10 + .bind(namespace_id) // $11 + .bind(column_set) // $12 + .bind(max_l0_created_at); // $13 + + let parquet_file_id = query.fetch_one(executor).await.map_err(|e| { + if is_unique_violation(&e) { + Error::AlreadyExists { + descr: object_store_id.to_string(), + } + } else if is_fk_violation(&e) { + Error::NotFound { + descr: e.to_string(), + } + } else { + Error::External { + source: Box::new(e), + } + } + })?; + + Ok(parquet_file_id) +} + +async fn flag_for_delete<'q, E>( + executor: E, + partition_id: PartitionId, + ids: &[ObjectStoreId], + marked_at: Timestamp, +) -> Result<()> +where + E: Executor<'q, Database = Postgres>, +{ + let updated = + sqlx::query_as::<_, (i64,)>(r#"UPDATE parquet_file SET to_delete = $1 WHERE object_store_id = ANY($2) AND partition_id = $3 AND to_delete is NULL RETURNING id;"#) + .bind(marked_at) // $1 + .bind(ids) // $2 + .bind(partition_id) // 
$3 + .fetch_all(executor) + .await?; + + if updated.len() != ids.len() { + return Err(Error::NotFound { + descr: "parquet file(s) not found for delete".to_string(), + }); + } + + Ok(()) +} + +async fn update_compaction_level<'q, E>( + executor: E, + partition_id: PartitionId, + parquet_file_ids: &[ObjectStoreId], + compaction_level: CompactionLevel, +) -> Result<()> +where + E: Executor<'q, Database = Postgres>, +{ + let updated = sqlx::query_as::<_, (i64,)>( + r#" +UPDATE parquet_file +SET compaction_level = $1 +WHERE object_store_id = ANY($2) AND partition_id = $3 AND to_delete is NULL RETURNING id; + "#, + ) + .bind(compaction_level) // $1 + .bind(parquet_file_ids) // $2 + .bind(partition_id) // $3 + .fetch_all(executor) + .await?; + + if updated.len() != parquet_file_ids.len() { + return Err(Error::NotFound { + descr: "parquet file(s) not found for upgrade".to_string(), + }); + } + + Ok(()) +} + +/// The error code returned by Postgres for a unique constraint violation. +/// +/// See +const PG_UNIQUE_VIOLATION: &str = "23505"; + +/// Returns true if `e` is a unique constraint violation error. +fn is_unique_violation(e: &sqlx::Error) -> bool { + if let sqlx::Error::Database(inner) = e { + if let Some(code) = inner.code() { + if code == PG_UNIQUE_VIOLATION { + return true; + } + } + } + + false +} + +/// Error code returned by Postgres for a foreign key constraint violation. +const PG_FK_VIOLATION: &str = "23503"; + +fn is_fk_violation(e: &sqlx::Error) -> bool { + if let sqlx::Error::Database(inner) = e { + if let Some(code) = inner.code() { + if code == PG_FK_VIOLATION { + return true; + } + } + } + + false +} + +/// Test helpers postgres testing. +#[cfg(test)] +pub(crate) mod test_utils { + use super::*; + use rand::Rng; + use sqlx::migrate::MigrateDatabase; + + pub(crate) const TEST_DSN_ENV: &str = "TEST_INFLUXDB_IOX_CATALOG_DSN"; + + /// Helper macro to skip tests if TEST_INTEGRATION and TEST_INFLUXDB_IOX_CATALOG_DSN environment + /// variables are not set. 
+ macro_rules! maybe_skip_integration { + ($panic_msg:expr) => {{ + dotenvy::dotenv().ok(); + + let required_vars = [crate::postgres::test_utils::TEST_DSN_ENV]; + let unset_vars: Vec<_> = required_vars + .iter() + .filter_map(|&name| match std::env::var(name) { + Ok(_) => None, + Err(_) => Some(name), + }) + .collect(); + let unset_var_names = unset_vars.join(", "); + + let force = std::env::var("TEST_INTEGRATION"); + + if force.is_ok() && !unset_var_names.is_empty() { + panic!( + "TEST_INTEGRATION is set, \ + but variable(s) {} need to be set", + unset_var_names + ); + } else if force.is_err() { + eprintln!( + "skipping Postgres integration test - set {}TEST_INTEGRATION to run", + if unset_var_names.is_empty() { + String::new() + } else { + format!("{} and ", unset_var_names) + } + ); + + let panic_msg: &'static str = $panic_msg; + if !panic_msg.is_empty() { + panic!("{}", panic_msg); + } + + return; + } + }}; + () => { + maybe_skip_integration!("") + }; + } + + pub(crate) use maybe_skip_integration; + + pub(crate) async fn create_db(dsn: &str) { + // Create the catalog database if it doesn't exist + if !Postgres::database_exists(dsn).await.unwrap() { + // Ignore failure if another test has already created the database + let _ = Postgres::create_database(dsn).await; + } + } + + pub(crate) async fn setup_db_no_migration() -> PostgresCatalog { + // create a random schema for this particular pool + let schema_name = { + // use scope to make it clear to clippy / rust that `rng` is + // not carried past await points + let mut rng = rand::thread_rng(); + (&mut rng) + .sample_iter(rand::distributions::Alphanumeric) + .filter(|c| c.is_ascii_alphabetic()) + .take(20) + .map(char::from) + .collect::() + .to_ascii_lowercase() + }; + info!(schema_name, "test schema"); + + let metrics = Arc::new(metric::Registry::default()); + let dsn = std::env::var("TEST_INFLUXDB_IOX_CATALOG_DSN").unwrap(); + + create_db(&dsn).await; + + let options = PostgresConnectionOptions { + app_name: 
String::from("test"), + schema_name: schema_name.clone(), + dsn, + max_conns: 3, + ..Default::default() + }; + let pg = PostgresCatalog::connect(options, metrics) + .await + .expect("failed to connect catalog"); + + // Create the test schema + pg.pool + .execute(format!("CREATE SCHEMA {schema_name};").as_str()) + .await + .expect("failed to create test schema"); + + // Ensure the test user has permission to interact with the test schema. + pg.pool + .execute( + format!( + "GRANT USAGE ON SCHEMA {schema_name} TO public; GRANT CREATE ON SCHEMA {schema_name} TO public;" + ) + .as_str(), + ) + .await + .expect("failed to grant privileges to schema"); + + pg + } + + pub(crate) async fn setup_db() -> PostgresCatalog { + let pg = setup_db_no_migration().await; + // Run the migrations against this random schema. + pg.setup().await.expect("failed to initialise database"); + pg + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::interface::ParquetFileRepoExt; + use crate::{ + postgres::test_utils::{ + create_db, maybe_skip_integration, setup_db, setup_db_no_migration, + }, + test_helpers::{arbitrary_namespace, arbitrary_parquet_file_params, arbitrary_table}, + }; + use assert_matches::assert_matches; + use data_types::partition_template::TemplatePart; + use generated_types::influxdata::iox::partition_template::v1 as proto; + use metric::{Observation, RawReporter}; + use std::{io::Write, ops::Deref, sync::Arc, time::Instant}; + use tempfile::NamedTempFile; + use test_helpers::maybe_start_logging; + + /// Small no-op test just to print out the migrations. + /// + /// This is helpful to look up migration checksums and debug parsing of the migration files. 
+ #[test] + fn print_migrations() { + println!("{:#?}", MIGRATOR.deref()); + } + + #[tokio::test] + async fn test_migration() { + maybe_skip_integration!(); + maybe_start_logging(); + + let postgres = setup_db_no_migration().await; + + // 1st setup + postgres.setup().await.unwrap(); + + // 2nd setup + postgres.setup().await.unwrap(); + } + + #[tokio::test] + async fn test_migration_generic() { + use crate::migrate::test_utils::test_migration; + + maybe_skip_integration!(); + maybe_start_logging(); + + test_migration(&MIGRATOR, || async { + setup_db_no_migration().await.into_pool() + }) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_catalog() { + maybe_skip_integration!(); + + let postgres = setup_db().await; + + // Validate the connection time zone is the expected UTC value. + let tz: String = sqlx::query_scalar("SHOW TIME ZONE;") + .fetch_one(&postgres.pool) + .await + .expect("read time zone"); + assert_eq!(tz, "UTC"); + + let pool = postgres.pool.clone(); + let schema_name = postgres.schema_name().to_string(); + + let postgres: Arc = Arc::new(postgres); + + crate::interface_tests::test_catalog(|| async { + // Clean the schema. + pool + .execute(format!("DROP SCHEMA {schema_name} CASCADE").as_str()) + .await + .expect("failed to clean schema between tests"); + + // Recreate the test schema + pool + .execute(format!("CREATE SCHEMA {schema_name};").as_str()) + .await + .expect("failed to create test schema"); + + // Ensure the test user has permission to interact with the test schema. + pool + .execute( + format!( + "GRANT USAGE ON SCHEMA {schema_name} TO public; GRANT CREATE ON SCHEMA {schema_name} TO public;" + ) + .as_str(), + ) + .await + .expect("failed to grant privileges to schema"); + + // Run the migrations against this random schema. 
+ postgres.setup().await.expect("failed to initialise database"); + + Arc::clone(&postgres) + }) + .await; + } + + #[tokio::test] + async fn existing_partitions_without_hash_id() { + maybe_skip_integration!(); + + let postgres = setup_db().await; + let pool = postgres.pool.clone(); + let postgres: Arc = Arc::new(postgres); + let mut repos = postgres.repositories(); + + let namespace = arbitrary_namespace(&mut *repos, "ns4").await; + let table = arbitrary_table(&mut *repos, "table", &namespace).await; + let table_id = table.id; + let key = PartitionKey::from("francis-scott-key-key"); + + // Create a partition record in the database that has `NULL` for its `hash_id` + // value, which is what records existing before the migration adding that column will have. + sqlx::query( + r#" +INSERT INTO partition + (partition_key, table_id, sort_key_ids) +VALUES + ( $1, $2, '{}') +ON CONFLICT ON CONSTRAINT partition_key_unique +DO UPDATE SET partition_key = partition.partition_key +RETURNING id, hash_id, table_id, partition_key, sort_key_ids, new_file_at; + "#, + ) + .bind(&key) // $1 + .bind(table_id) // $2 + .fetch_one(&pool) + .await + .unwrap(); + + // Check that the hash_id being null in the database doesn't break querying for partitions. + let table_partitions = repos.partitions().list_by_table_id(table_id).await.unwrap(); + assert_eq!(table_partitions.len(), 1); + let partition = &table_partitions[0]; + assert!(partition.hash_id().is_none()); + + // Call create_or_get for the same (key, table_id) pair, to ensure the write is idempotent + // and that the hash_id still doesn't get set. 
+ let inserted_again = repos + .partitions() + .create_or_get(key, table_id) + .await + .expect("idempotent write should succeed"); + + // Test: sort_key_ids from freshly insert with empty value + assert!(inserted_again.sort_key_ids().is_none()); + + assert_eq!(partition, &inserted_again); + + // Create a Parquet file record in this partition to ensure we don't break new data + // ingestion for old-style partitions + let parquet_file_params = arbitrary_parquet_file_params(&namespace, &table, partition); + let parquet_file = repos + .parquet_files() + .create(parquet_file_params) + .await + .unwrap(); + assert_eq!(parquet_file.partition_hash_id, None); + + // Add a partition record WITH a hash ID + repos + .partitions() + .create_or_get(PartitionKey::from("Something else"), table_id) + .await + .unwrap(); + + // Ensure we can list only the old-style partitions + let old_style_partitions = repos.partitions().list_old_style().await.unwrap(); + assert_eq!(old_style_partitions.len(), 1); + assert_eq!(old_style_partitions[0].id, partition.id); + } + + #[test] + fn test_parse_dsn_file() { + assert_eq!( + get_dsn_file_path("dsn-file:///tmp/my foo.txt"), + Some("/tmp/my foo.txt".to_owned()), + ); + assert_eq!(get_dsn_file_path("dsn-file:blah"), None,); + assert_eq!(get_dsn_file_path("postgres://user:pw@host/db"), None,); + } + + #[tokio::test] + async fn test_reload() { + maybe_skip_integration!(); + + const POLLING_INTERVAL: Duration = Duration::from_millis(10); + + // fetch dsn from envvar + let test_dsn = std::env::var("TEST_INFLUXDB_IOX_CATALOG_DSN").unwrap(); + create_db(&test_dsn).await; + eprintln!("TEST_DSN={test_dsn}"); + + // create a temp file to store the initial dsn + let mut dsn_file = NamedTempFile::new().expect("create temp file"); + dsn_file + .write_all(test_dsn.as_bytes()) + .expect("write temp file"); + + const TEST_APPLICATION_NAME: &str = "test_application_name"; + let dsn_good = format!("dsn-file://{}", dsn_file.path().display()); + 
eprintln!("dsn_good={dsn_good}"); + + // create a hot swap pool with test application name and dsn file pointing to tmp file. + // we will later update this file and the pool should be replaced. + let options = PostgresConnectionOptions { + app_name: TEST_APPLICATION_NAME.to_owned(), + schema_name: String::from("test"), + dsn: dsn_good, + max_conns: 3, + hotswap_poll_interval: POLLING_INTERVAL, + ..Default::default() + }; + let metrics = Arc::new(metric::Registry::new()); + let pool = new_pool(&options, metrics).await.expect("connect"); + eprintln!("got a pool"); + + // ensure the application name is set as expected + let application_name: String = + sqlx::query_scalar("SELECT current_setting('application_name') as application_name;") + .fetch_one(&pool) + .await + .expect("read application_name"); + assert_eq!(application_name, TEST_APPLICATION_NAME); + + // create a new temp file object with updated dsn and overwrite the previous tmp file + const TEST_APPLICATION_NAME_NEW: &str = "changed_application_name"; + let mut new_dsn_file = NamedTempFile::new().expect("create temp file"); + new_dsn_file + .write_all(test_dsn.as_bytes()) + .expect("write temp file"); + new_dsn_file + .write_all(format!("?application_name={TEST_APPLICATION_NAME_NEW}").as_bytes()) + .expect("write temp file"); + new_dsn_file + .persist(dsn_file.path()) + .expect("overwrite new dsn file"); + + // wait until the hotswap machinery has reloaded the updated DSN file and + // successfully performed a new connection with the new DSN. 
+ let mut application_name = "".to_string(); + let start = Instant::now(); + while start.elapsed() < Duration::from_secs(5) + && application_name != TEST_APPLICATION_NAME_NEW + { + tokio::time::sleep(POLLING_INTERVAL).await; + + application_name = sqlx::query_scalar( + "SELECT current_setting('application_name') as application_name;", + ) + .fetch_one(&pool) + .await + .expect("read application_name"); + } + assert_eq!(application_name, TEST_APPLICATION_NAME_NEW); + } + + #[tokio::test] + async fn test_billing_summary_on_parqet_file_creation() { + maybe_skip_integration!(); + + let postgres = setup_db().await; + let pool = postgres.pool.clone(); + let postgres: Arc = Arc::new(postgres); + let mut repos = postgres.repositories(); + let namespace = arbitrary_namespace(&mut *repos, "ns4").await; + let table = arbitrary_table(&mut *repos, "table", &namespace).await; + let key = "bananas"; + let partition = repos + .partitions() + .create_or_get(key.into(), table.id) + .await + .unwrap(); + + // parquet file to create- all we care about here is the size + let mut p1 = arbitrary_parquet_file_params(&namespace, &table, &partition); + p1.file_size_bytes = 1337; + let f1 = repos.parquet_files().create(p1.clone()).await.unwrap(); + // insert the same again with a different size; we should then have 3x1337 as total file + // size + p1.object_store_id = ObjectStoreId::new(); + p1.file_size_bytes *= 2; + let _f2 = repos + .parquet_files() + .create(p1.clone()) + .await + .expect("create parquet file should succeed"); + + // after adding two files we should have 3x1337 in the summary + let total_file_size_bytes: i64 = + sqlx::query_scalar("SELECT total_file_size_bytes FROM billing_summary;") + .fetch_one(&pool) + .await + .expect("fetch total file size failed"); + assert_eq!(total_file_size_bytes, 1337 * 3); + + // flag f1 for deletion and assert that the total file size is reduced accordingly. 
+ repos + .parquet_files() + .create_upgrade_delete( + partition.id, + &[f1.object_store_id], + &[], + &[], + CompactionLevel::Initial, + ) + .await + .expect("flag parquet file for deletion should succeed"); + let total_file_size_bytes: i64 = + sqlx::query_scalar("SELECT total_file_size_bytes FROM billing_summary;") + .fetch_one(&pool) + .await + .expect("fetch total file size failed"); + // we marked the first file of size 1337 for deletion leaving only the second that was 2x + // that + assert_eq!(total_file_size_bytes, 1337 * 2); + + // actually deleting shouldn't change the total + let older_than = p1.created_at + 1; + repos + .parquet_files() + .delete_old_ids_only(older_than) + .await + .expect("parquet file deletion should succeed"); + let total_file_size_bytes: i64 = + sqlx::query_scalar("SELECT total_file_size_bytes FROM billing_summary;") + .fetch_one(&pool) + .await + .expect("fetch total file size failed"); + assert_eq!(total_file_size_bytes, 1337 * 2); + } + + #[tokio::test] + async fn namespace_partition_template_null_is_the_default_in_the_database() { + maybe_skip_integration!(); + + let postgres = setup_db().await; + let pool = postgres.pool.clone(); + let postgres: Arc = Arc::new(postgres); + let mut repos = postgres.repositories(); + + let namespace_name = "apples"; + + // Create a namespace record in the database that has `NULL` for its `partition_template` + // value, which is what records existing before the migration adding that column will have. 
+ let insert_null_partition_template_namespace = sqlx::query( + r#" +INSERT INTO namespace ( + name, retention_period_ns, partition_template +) +VALUES ( $1, $2, NULL ) +RETURNING id, name, retention_period_ns, max_tables, max_columns_per_table, deleted_at, + partition_template; + "#, + ) + .bind(namespace_name) // $1 + .bind(None::>); // $2 + + insert_null_partition_template_namespace + .fetch_one(&pool) + .await + .unwrap(); + + let lookup_namespace = repos + .namespaces() + .get_by_name(namespace_name, SoftDeletedRows::ExcludeDeleted) + .await + .unwrap() + .unwrap(); + // When fetching this namespace from the database, the `FromRow` impl should set its + // `partition_template` to the default. + assert_eq!( + lookup_namespace.partition_template, + NamespacePartitionTemplateOverride::default() + ); + + // When creating a namespace through the catalog functions without specifying a custom + // partition template, + let created_without_custom_template = repos + .namespaces() + .create( + &"lemons".try_into().unwrap(), + None, // no partition template + None, + None, + ) + .await + .unwrap(); + + // it should have the default template in the application, + assert_eq!( + created_without_custom_template.partition_template, + NamespacePartitionTemplateOverride::default() + ); + + // and store NULL in the database record. + let record = sqlx::query("SELECT name, partition_template FROM namespace WHERE id = $1;") + .bind(created_without_custom_template.id) + .fetch_one(&pool) + .await + .unwrap(); + let name: String = record.try_get("name").unwrap(); + assert_eq!(created_without_custom_template.name, name); + let partition_template: Option = + record.try_get("partition_template").unwrap(); + assert!(partition_template.is_none()); + + // When explicitly setting a template that happens to be equal to the application default, + // assume it's important that it's being specially requested and store it rather than NULL. 
+ let namespace_custom_template_name = "kumquats"; + let custom_partition_template_equal_to_default = + NamespacePartitionTemplateOverride::try_from(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat( + "%Y-%m-%d".to_owned(), + )), + }], + }) + .unwrap(); + let namespace_custom_template = repos + .namespaces() + .create( + &namespace_custom_template_name.try_into().unwrap(), + Some(custom_partition_template_equal_to_default.clone()), + None, + None, + ) + .await + .unwrap(); + assert_eq!( + namespace_custom_template.partition_template, + custom_partition_template_equal_to_default + ); + let record = sqlx::query("SELECT name, partition_template FROM namespace WHERE id = $1;") + .bind(namespace_custom_template.id) + .fetch_one(&pool) + .await + .unwrap(); + let name: String = record.try_get("name").unwrap(); + assert_eq!(namespace_custom_template.name, name); + let partition_template: Option = + record.try_get("partition_template").unwrap(); + assert_eq!( + partition_template.unwrap(), + custom_partition_template_equal_to_default + ); + } + + #[tokio::test] + async fn table_partition_template_null_is_the_default_in_the_database() { + maybe_skip_integration!(); + + let postgres = setup_db().await; + let pool = postgres.pool.clone(); + let postgres: Arc = Arc::new(postgres); + let mut repos = postgres.repositories(); + + let namespace_default_template_name = "oranges"; + let namespace_default_template = repos + .namespaces() + .create( + &namespace_default_template_name.try_into().unwrap(), + None, // no partition template + None, + None, + ) + .await + .unwrap(); + + let namespace_custom_template_name = "limes"; + let namespace_custom_template = repos + .namespaces() + .create( + &namespace_custom_template_name.try_into().unwrap(), + Some( + NamespacePartitionTemplateOverride::try_from(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: 
Some(proto::template_part::Part::TimeFormat("year-%Y".into())), + }], + }) + .unwrap(), + ), + None, + None, + ) + .await + .unwrap(); + + // In a namespace that also has a NULL template, create a table record in the database that + // has `NULL` for its `partition_template` value, which is what records existing before the + // migration adding that column will have. + let table_name = "null_template"; + let insert_null_partition_template_table = sqlx::query( + r#" +INSERT INTO table_name ( name, namespace_id, partition_template ) +VALUES ( $1, $2, NULL ) +RETURNING *; + "#, + ) + .bind(table_name) // $1 + .bind(namespace_default_template.id); // $2 + + insert_null_partition_template_table + .fetch_one(&pool) + .await + .unwrap(); + + let lookup_table = repos + .tables() + .get_by_namespace_and_name(namespace_default_template.id, table_name) + .await + .unwrap() + .unwrap(); + // When fetching this table from the database, the `FromRow` impl should set its + // `partition_template` to the system default (because the namespace didn't have a template + // either). + assert_eq!( + lookup_table.partition_template, + TablePartitionTemplateOverride::default() + ); + + // In a namespace that has a custom template, create a table record in the database that + // has `NULL` for its `partition_template` value. + // + // THIS ACTUALLY SHOULD BE IMPOSSIBLE because: + // + // * Namespaces have to exist before tables + // * `partition_tables` are immutable on both namespaces and tables + // * When the migration adding the `partition_table` column is deployed, namespaces can + // begin to be created with `partition_templates` + // * *Then* tables can be created with `partition_templates` or not + // * When tables don't get a custom table partition template but their namespace has one, + // their database record will get the namespace partition template. 
+ // + // In other words, table `partition_template` values in the database is allowed to possibly + // be `NULL` IFF their namespace's `partition_template` is `NULL`. + // + // That said, this test creates this hopefully-impossible scenario to ensure that the + // defined, expected behavior if a table record somehow exists in the database with a `NULL` + // `partition_template` value is that it will have the application default partition + // template *even if the namespace `partition_template` is not null*. + let table_name = "null_template"; + let insert_null_partition_template_table = sqlx::query( + r#" +INSERT INTO table_name ( name, namespace_id, partition_template ) +VALUES ( $1, $2, NULL ) +RETURNING *; + "#, + ) + .bind(table_name) // $1 + .bind(namespace_custom_template.id); // $2 + + insert_null_partition_template_table + .fetch_one(&pool) + .await + .unwrap(); + + let lookup_table = repos + .tables() + .get_by_namespace_and_name(namespace_custom_template.id, table_name) + .await + .unwrap() + .unwrap(); + // When fetching this table from the database, the `FromRow` impl should set its + // `partition_template` to the system default *even though the namespace has a + // template*, because this should be impossible as detailed above. 
+ assert_eq!( + lookup_table.partition_template, + TablePartitionTemplateOverride::default() + ); + + // # Table template false, namespace template true + // + // When creating a table through the catalog functions *without* a custom table template in + // a namespace *with* a custom partition template, + let table_no_template_with_namespace_template = repos + .tables() + .create( + "pomelo", + TablePartitionTemplateOverride::try_new( + None, // no custom partition template + &namespace_custom_template.partition_template, + ) + .unwrap(), + namespace_custom_template.id, + ) + .await + .unwrap(); + + // it should have the namespace's template + assert_eq!( + table_no_template_with_namespace_template.partition_template, + TablePartitionTemplateOverride::try_new( + None, + &namespace_custom_template.partition_template + ) + .unwrap() + ); + + // and store that value in the database record. + let record = sqlx::query("SELECT name, partition_template FROM table_name WHERE id = $1;") + .bind(table_no_template_with_namespace_template.id) + .fetch_one(&pool) + .await + .unwrap(); + let name: String = record.try_get("name").unwrap(); + assert_eq!(table_no_template_with_namespace_template.name, name); + let partition_template: Option = + record.try_get("partition_template").unwrap(); + assert_eq!( + partition_template.unwrap(), + TablePartitionTemplateOverride::try_new( + None, + &namespace_custom_template.partition_template + ) + .unwrap() + ); + + // # Table template true, namespace template false + // + // When creating a table through the catalog functions *with* a custom table template in + // a namespace *without* a custom partition template, + let custom_table_template = proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue("chemical".into())), + }], + }; + let table_with_template_no_namespace_template = repos + .tables() + .create( + "tangerine", + TablePartitionTemplateOverride::try_new( + 
Some(custom_table_template), // with custom partition template + &namespace_default_template.partition_template, + ) + .unwrap(), + namespace_default_template.id, + ) + .await + .unwrap(); + + // it should have the custom table template + let table_template_parts: Vec<_> = table_with_template_no_namespace_template + .partition_template + .parts() + .collect(); + assert_eq!(table_template_parts.len(), 1); + assert_matches!( + table_template_parts[0], + TemplatePart::TagValue(tag) if tag == "chemical" + ); + + // and store that value in the database record. + let record = sqlx::query("SELECT name, partition_template FROM table_name WHERE id = $1;") + .bind(table_with_template_no_namespace_template.id) + .fetch_one(&pool) + .await + .unwrap(); + let name: String = record.try_get("name").unwrap(); + assert_eq!(table_with_template_no_namespace_template.name, name); + let partition_template = record + .try_get::, _>("partition_template") + .unwrap() + .unwrap(); + let table_template_parts: Vec<_> = partition_template.parts().collect(); + assert_eq!(table_template_parts.len(), 1); + assert_matches!( + table_template_parts[0], + TemplatePart::TagValue(tag) if tag == "chemical" + ); + + // # Table template true, namespace template true + // + // When creating a table through the catalog functions *with* a custom table template in + // a namespace *with* a custom partition template, + let custom_table_template = proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue("vegetable".into())), + }], + }; + let table_with_template_with_namespace_template = repos + .tables() + .create( + "nectarine", + TablePartitionTemplateOverride::try_new( + Some(custom_table_template), // with custom partition template + &namespace_custom_template.partition_template, + ) + .unwrap(), + namespace_custom_template.id, + ) + .await + .unwrap(); + + // it should have the custom table template + let table_template_parts: Vec<_> = 
table_with_template_with_namespace_template + .partition_template + .parts() + .collect(); + assert_eq!(table_template_parts.len(), 1); + assert_matches!( + table_template_parts[0], + TemplatePart::TagValue(tag) if tag == "vegetable" + ); + + // and store that value in the database record. + let record = sqlx::query("SELECT name, partition_template FROM table_name WHERE id = $1;") + .bind(table_with_template_with_namespace_template.id) + .fetch_one(&pool) + .await + .unwrap(); + let name: String = record.try_get("name").unwrap(); + assert_eq!(table_with_template_with_namespace_template.name, name); + let partition_template = record + .try_get::, _>("partition_template") + .unwrap() + .unwrap(); + let table_template_parts: Vec<_> = partition_template.parts().collect(); + assert_eq!(table_template_parts.len(), 1); + assert_matches!( + table_template_parts[0], + TemplatePart::TagValue(tag) if tag == "vegetable" + ); + + // # Table template false, namespace template false + // + // When creating a table through the catalog functions *without* a custom table template in + // a namespace *without* a custom partition template, + let table_no_template_no_namespace_template = repos + .tables() + .create( + "grapefruit", + TablePartitionTemplateOverride::try_new( + None, // no custom partition template + &namespace_default_template.partition_template, + ) + .unwrap(), + namespace_default_template.id, + ) + .await + .unwrap(); + + // it should have the default template in the application, + assert_eq!( + table_no_template_no_namespace_template.partition_template, + TablePartitionTemplateOverride::default() + ); + + // and store NULL in the database record. 
+ let record = sqlx::query("SELECT name, partition_template FROM table_name WHERE id = $1;") + .bind(table_no_template_no_namespace_template.id) + .fetch_one(&pool) + .await + .unwrap(); + let name: String = record.try_get("name").unwrap(); + assert_eq!(table_no_template_no_namespace_template.name, name); + let partition_template: Option = + record.try_get("partition_template").unwrap(); + assert!(partition_template.is_none()); + } + + #[tokio::test] + async fn test_metrics() { + maybe_skip_integration!(); + + let postgres = setup_db_no_migration().await; + + let mut reporter = RawReporter::default(); + postgres.metrics.report(&mut reporter); + assert_eq!( + reporter + .metric("sqlx_postgres_connections") + .unwrap() + .observation(&[("pool_id", "0"), ("state", "min")]) + .unwrap(), + &Observation::U64Gauge(1), + ); + assert_eq!( + reporter + .metric("sqlx_postgres_connections") + .unwrap() + .observation(&[("pool_id", "0"), ("state", "max")]) + .unwrap(), + &Observation::U64Gauge(3), + ); + } +} diff --git a/iox_catalog/src/sqlite.rs b/iox_catalog/src/sqlite.rs new file mode 100644 index 0000000..e91cde3 --- /dev/null +++ b/iox_catalog/src/sqlite.rs @@ -0,0 +1,2196 @@ +//! 
A SQLite backed implementation of the Catalog

use crate::interface::PartitionRepoExt;
use crate::{
    constants::{
        MAX_PARQUET_FILES_SELECTED_ONCE_FOR_DELETE, MAX_PARQUET_FILES_SELECTED_ONCE_FOR_RETENTION,
    },
    interface::{
        AlreadyExistsSnafu, CasFailure, Catalog, ColumnRepo, Error, NamespaceRepo, ParquetFileRepo,
        PartitionRepo, RepoCollection, Result, SoftDeletedRows, TableRepo,
    },
    metrics::MetricDecorator,
};
use async_trait::async_trait;
use data_types::snapshot::partition::PartitionSnapshot;
use data_types::snapshot::table::TableSnapshot;
use data_types::{
    partition_template::{
        NamespacePartitionTemplateOverride, TablePartitionTemplateOverride, TemplatePart,
    },
    Column, ColumnId, ColumnSet, ColumnType, CompactionLevel, MaxColumnsPerTable, MaxTables,
    Namespace, NamespaceId, NamespaceName, NamespaceServiceProtectionLimitsOverride, ObjectStoreId,
    ParquetFile, ParquetFileId, ParquetFileParams, Partition, PartitionHashId, PartitionId,
    PartitionKey, SkippedCompaction, SortKeyIds, Table, TableId, Timestamp,
};
use iox_time::{SystemProvider, TimeProvider};
use metric::Registry;
use observability_deps::tracing::debug;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use snafu::prelude::*;
use sqlx::{
    migrate::Migrator,
    sqlite::{SqliteConnectOptions, SqliteRow},
    types::Json,
    Executor, FromRow, Pool, Row, Sqlite, SqlitePool,
};
use std::{
    collections::{HashMap, HashSet},
    fmt::Display,
    str::FromStr,
    sync::Arc,
};

/// Embedded schema migrations, applied by [`Catalog::setup`].
static MIGRATOR: Migrator = sqlx::migrate!("sqlite/migrations");

/// SQLite connection options.
#[derive(Debug, Clone)]
pub struct SqliteConnectionOptions {
    /// local file path to .sqlite file
    pub file_path: String,
}

/// SQLite catalog.
//
// NOTE(review): generic type parameters in this file were lost in extraction and have been
// reconstructed from usage — confirm against the upstream source.
#[derive(Debug)]
pub struct SqliteCatalog {
    metrics: Arc<Registry>,
    pool: Pool<Sqlite>,
    time_provider: Arc<dyn TimeProvider>,
    options: SqliteConnectionOptions,
}

/// transaction for [`SqliteCatalog`].
#[derive(Debug)]
pub struct SqliteTxn {
    inner: Mutex<SqliteTxnInner>,
    time_provider: Arc<dyn TimeProvider>,
}

#[derive(Debug)]
struct SqliteTxnInner {
    pool: Pool<Sqlite>,
}

// Delegate the entire [`Executor`] surface to the wrapped pool so a `&mut SqliteTxnInner` can be
// passed anywhere a SQLite executor is expected.
impl<'c> Executor<'c> for &'c mut SqliteTxnInner {
    type Database = Sqlite;

    #[allow(clippy::type_complexity)]
    fn fetch_many<'e, 'q: 'e, E: 'q>(
        self,
        query: E,
    ) -> futures::stream::BoxStream<
        'e,
        Result<
            sqlx::Either<
                <Self::Database as sqlx::Database>::QueryResult,
                <Self::Database as sqlx::Database>::Row,
            >,
            sqlx::Error,
        >,
    >
    where
        'c: 'e,
        E: sqlx::Execute<'q, Self::Database>,
    {
        self.pool.fetch_many(query)
    }

    fn fetch_optional<'e, 'q: 'e, E: 'q>(
        self,
        query: E,
    ) -> futures::future::BoxFuture<
        'e,
        Result<Option<<Self::Database as sqlx::Database>::Row>, sqlx::Error>,
    >
    where
        'c: 'e,
        E: sqlx::Execute<'q, Self::Database>,
    {
        self.pool.fetch_optional(query)
    }

    fn prepare_with<'e, 'q: 'e>(
        self,
        sql: &'q str,
        parameters: &'e [<Self::Database as sqlx::Database>::TypeInfo],
    ) -> futures::future::BoxFuture<
        'e,
        Result<<Self::Database as sqlx::database::HasStatement<'q>>::Statement, sqlx::Error>,
    >
    where
        'c: 'e,
    {
        self.pool.prepare_with(sql, parameters)
    }

    fn describe<'e, 'q: 'e>(
        self,
        sql: &'q str,
    ) -> futures::future::BoxFuture<'e, Result<sqlx::Describe<Self::Database>, sqlx::Error>>
    where
        'c: 'e,
    {
        self.pool.describe(sql)
    }
}

impl SqliteCatalog {
    /// Connect to the catalog store.
+ pub async fn connect(options: SqliteConnectionOptions, metrics: Arc) -> Result { + let opts = SqliteConnectOptions::from_str(&options.file_path)?.create_if_missing(true); + + let pool = SqlitePool::connect_with(opts).await?; + Ok(Self { + metrics, + pool, + time_provider: Arc::new(SystemProvider::new()), + options, + }) + } +} + +impl Display for SqliteCatalog { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Sqlite(dsn='{}')", self.options.file_path) + } +} + +#[async_trait] +impl Catalog for SqliteCatalog { + async fn setup(&self) -> Result<()> { + MIGRATOR.run(&self.pool).await?; + + Ok(()) + } + + fn repositories(&self) -> Box { + Box::new(MetricDecorator::new( + SqliteTxn { + inner: Mutex::new(SqliteTxnInner { + pool: self.pool.clone(), + }), + time_provider: Arc::clone(&self.time_provider), + }, + Arc::clone(&self.metrics), + Arc::clone(&self.time_provider), + )) + } + + #[cfg(test)] + fn metrics(&self) -> Arc { + Arc::clone(&self.metrics) + } + + fn time_provider(&self) -> Arc { + Arc::clone(&self.time_provider) + } +} + +impl RepoCollection for SqliteTxn { + fn namespaces(&mut self) -> &mut dyn NamespaceRepo { + self + } + + fn tables(&mut self) -> &mut dyn TableRepo { + self + } + + fn columns(&mut self) -> &mut dyn ColumnRepo { + self + } + + fn partitions(&mut self) -> &mut dyn PartitionRepo { + self + } + + fn parquet_files(&mut self) -> &mut dyn ParquetFileRepo { + self + } +} + +#[async_trait] +impl NamespaceRepo for SqliteTxn { + async fn create( + &mut self, + name: &NamespaceName<'_>, + partition_template: Option, + retention_period_ns: Option, + service_protection_limits: Option, + ) -> Result { + let max_tables = service_protection_limits + .and_then(|l| l.max_tables) + .unwrap_or_default(); + let max_columns_per_table = service_protection_limits + .and_then(|l| l.max_columns_per_table) + .unwrap_or_default(); + + let rec = sqlx::query_as::<_, Namespace>( + r#" +INSERT INTO namespace ( name, 
retention_period_ns, max_tables, max_columns_per_table, partition_template )
VALUES ( $1, $2, $3, $4, $5 )
RETURNING id, name, retention_period_ns, max_tables, max_columns_per_table, deleted_at,
          partition_template;
        "#,
        )
        .bind(name.as_str()) // $1
        .bind(retention_period_ns) // $2
        .bind(max_tables) // $3
        .bind(max_columns_per_table) // $4
        .bind(partition_template); // $5

        let rec = rec.fetch_one(self.inner.get_mut()).await.map_err(|e| {
            if is_unique_violation(&e) {
                // The unique constraint on `name` fired: the namespace already exists.
                Error::AlreadyExists {
                    descr: name.to_string(),
                }
            } else if is_fk_violation(&e) {
                Error::NotFound {
                    descr: e.to_string(),
                }
            } else {
                Error::External {
                    source: Box::new(e),
                }
            }
        })?;

        Ok(rec)
    }

    async fn list(&mut self, deleted: SoftDeletedRows) -> Result<Vec<Namespace>> {
        let rec = sqlx::query_as::<_, Namespace>(
            format!(
                r#"
SELECT id, name, retention_period_ns, max_tables, max_columns_per_table, deleted_at,
       partition_template
FROM namespace
WHERE {v};
                "#,
                v = deleted.as_sql_predicate()
            )
            .as_str(),
        )
        .fetch_all(self.inner.get_mut())
        .await?;

        Ok(rec)
    }

    async fn get_by_id(
        &mut self,
        id: NamespaceId,
        deleted: SoftDeletedRows,
    ) -> Result<Option<Namespace>> {
        let rec = sqlx::query_as::<_, Namespace>(
            format!(
                r#"
SELECT id, name, retention_period_ns, max_tables, max_columns_per_table, deleted_at,
       partition_template
FROM namespace
WHERE id=$1 AND {v};
                "#,
                v = deleted.as_sql_predicate()
            )
            .as_str(),
        )
        .bind(id) // $1
        .fetch_one(self.inner.get_mut())
        .await;

        // A missing row is `Ok(None)`, not an error.
        if let Err(sqlx::Error::RowNotFound) = rec {
            return Ok(None);
        }

        let namespace = rec?;

        Ok(Some(namespace))
    }

    async fn get_by_name(
        &mut self,
        name: &str,
        deleted: SoftDeletedRows,
    ) -> Result<Option<Namespace>> {
        let rec = sqlx::query_as::<_, Namespace>(
            format!(
                r#"
SELECT id, name, retention_period_ns, max_tables, max_columns_per_table, deleted_at,
       partition_template
FROM namespace
WHERE name=$1 AND {v};
                "#,
                v = deleted.as_sql_predicate()
            )
            .as_str(),
        )
        .bind(name) // $1
        .fetch_one(self.inner.get_mut())
        .await;

        if let Err(sqlx::Error::RowNotFound) = rec {
            return Ok(None);
        }

        let namespace = rec?;

        Ok(Some(namespace))
    }

    async fn soft_delete(&mut self, name: &str) -> Result<()> {
        let flagged_at = Timestamp::from(self.time_provider.now());

        // note that there is a uniqueness constraint on the name column in the DB
        sqlx::query(r#"UPDATE namespace SET deleted_at=$1 WHERE name = $2;"#)
            .bind(flagged_at) // $1
            .bind(name) // $2
            .execute(self.inner.get_mut())
            .await
            .map_err(Error::from)
            .map(|_| ())
    }

    async fn update_table_limit(&mut self, name: &str, new_max: MaxTables) -> Result<Namespace> {
        let rec = sqlx::query_as::<_, Namespace>(
            r#"
UPDATE namespace
SET max_tables = $1
WHERE name = $2
RETURNING id, name, retention_period_ns, max_tables, max_columns_per_table, deleted_at,
          partition_template;
        "#,
        )
        .bind(new_max)
        .bind(name)
        .fetch_one(self.inner.get_mut())
        .await;

        let namespace = rec.map_err(|e| match e {
            sqlx::Error::RowNotFound => Error::NotFound {
                descr: name.to_string(),
            },
            _ => Error::External {
                source: Box::new(e),
            },
        })?;

        Ok(namespace)
    }

    async fn update_column_limit(
        &mut self,
        name: &str,
        new_max: MaxColumnsPerTable,
    ) -> Result<Namespace> {
        let rec = sqlx::query_as::<_, Namespace>(
            r#"
UPDATE namespace
SET max_columns_per_table = $1
WHERE name = $2
RETURNING id, name, retention_period_ns, max_tables, max_columns_per_table, deleted_at,
          partition_template;
        "#,
        )
        .bind(new_max)
        .bind(name)
        .fetch_one(self.inner.get_mut())
        .await;

        let namespace = rec.map_err(|e| match e {
            sqlx::Error::RowNotFound => Error::NotFound {
                descr: name.to_string(),
            },
            _ => Error::External {
                source: Box::new(e),
            },
        })?;

        Ok(namespace)
    }

    async fn update_retention_period(
        &mut self,
        name: &str,
        retention_period_ns: Option<i64>,
    ) -> Result<Namespace> {
        let rec = sqlx::query_as::<_, Namespace>(
            r#"
UPDATE namespace
SET retention_period_ns = $1
WHERE name = $2
RETURNING id, name, retention_period_ns, max_tables, max_columns_per_table, deleted_at,
          partition_template;
        "#,
        )
        .bind(retention_period_ns) // $1
        .bind(name) // $2
        .fetch_one(self.inner.get_mut())
        .await;

        let namespace = rec.map_err(|e| match e {
            sqlx::Error::RowNotFound => Error::NotFound {
                descr: name.to_string(),
            },
            _ => Error::External {
                source: Box::new(e),
            },
        })?;

        Ok(namespace)
    }
}

/// [`TableRepo::create`] needs the ability to create some columns within the same transaction as
/// the table creation. Column creation might also happen through [`ColumnRepo::create_or_get`],
/// which doesn't need to be within an outer transaction. This function was extracted so that these
/// two functions can share code but pass in either the transaction or the regular database
/// connection as the query executor.
async fn insert_column_with_connection<'q, E>(
    executor: E,
    name: &str,
    table_id: TableId,
    column_type: ColumnType,
) -> Result<Column>
where
    E: Executor<'q, Database = Sqlite>,
{
    let rec = sqlx::query_as::<_, Column>(
        r#"
INSERT INTO column_name ( name, table_id, column_type )
SELECT $1, table_id, $3 FROM (
    SELECT max_columns_per_table, namespace.id, table_name.id as table_id, COUNT(column_name.id) AS count
    FROM namespace LEFT JOIN table_name ON namespace.id = table_name.namespace_id
                   LEFT JOIN column_name ON table_name.id = column_name.table_id
    WHERE table_name.id = $2
    GROUP BY namespace.max_columns_per_table, namespace.id, table_name.id
) AS get_count WHERE count < max_columns_per_table
ON CONFLICT (table_id, name)
DO UPDATE SET name = column_name.name
RETURNING *;
        "#,
    )
    .bind(name) // $1
    .bind(table_id) // $2
    .bind(column_type) // $3
    .fetch_one(executor)
    .await
    .map_err(|e| match e {
        // The guarded SELECT inserted zero rows: the column limit was reached.
        sqlx::Error::RowNotFound => Error::LimitExceeded {
            descr: format!("couldn't create
column {} in table {}; limit reached on namespace", name, table_id) + }, + _ => { + if is_fk_violation(&e) { + Error::NotFound { descr: e.to_string() } + } else { + Error::External { source: Box::new(e) } + } + }})?; + + ensure!( + rec.column_type == column_type, + AlreadyExistsSnafu { + descr: format!( + "column {} is type {} but schema update has type {}", + name, rec.column_type, column_type + ), + } + ); + + Ok(rec) +} + +#[async_trait] +impl TableRepo for SqliteTxn { + async fn create( + &mut self, + name: &str, + partition_template: TablePartitionTemplateOverride, + namespace_id: NamespaceId, + ) -> Result
{ + let mut tx = self.inner.get_mut().pool.begin().await?; + + // A simple insert statement becomes quite complicated in order to avoid checking the table + // limits in a select and then conditionally inserting (which would be racey). + // + // from https://www.postgresql.org/docs/current/sql-insert.html + // "INSERT inserts new rows into a table. One can insert one or more rows specified by + // value expressions, or zero or more rows resulting from a query." + // By using SELECT rather than VALUES it will insert zero rows if it finds a null in the + // subquery, i.e. if count >= max_tables. fetch_one() will return a RowNotFound error if + // nothing was inserted. Not pretty! + let table = sqlx::query_as::<_, Table>( + r#" +INSERT INTO table_name ( name, namespace_id, partition_template ) +SELECT $1, id, $2 FROM ( + SELECT namespace.id AS id, max_tables, COUNT(table_name.id) AS count + FROM namespace LEFT JOIN table_name ON namespace.id = table_name.namespace_id + WHERE namespace.id = $3 + GROUP BY namespace.max_tables, table_name.namespace_id, namespace.id +) AS get_count WHERE count < max_tables +RETURNING *; + "#, + ) + .bind(name) // $1 + .bind(partition_template) // $2 + .bind(namespace_id) // $3 + .fetch_one(&mut *tx) + .await + .map_err(|e| match e { + sqlx::Error::RowNotFound => Error::LimitExceeded { + descr: format!( + "couldn't create table {}; limit reached on namespace {}", + name, namespace_id + ), + }, + _ => { + if is_unique_violation(&e) { + Error::AlreadyExists { + descr: format!("table '{name}' in namespace {namespace_id}"), + } + } else if is_fk_violation(&e) { + Error::NotFound { + descr: e.to_string(), + } + } else { + Error::External { + source: Box::new(e), + } + } + } + })?; + + // Partitioning is only supported for tags, so create tag columns for all `TagValue` + // partition template parts. 
It's important this happens within the table creation + // transaction so that there isn't a possibility of a concurrent write creating these + // columns with an unsupported type. + for template_part in table.partition_template.parts() { + if let TemplatePart::TagValue(tag_name) = template_part { + insert_column_with_connection(&mut *tx, tag_name, table.id, ColumnType::Tag) + .await?; + } + } + + tx.commit().await?; + + Ok(table) + } + + async fn get_by_id(&mut self, table_id: TableId) -> Result> { + let rec = sqlx::query_as::<_, Table>( + r#" +SELECT * +FROM table_name +WHERE id = $1; + "#, + ) + .bind(table_id) // $1 + .fetch_one(self.inner.get_mut()) + .await; + + if let Err(sqlx::Error::RowNotFound) = rec { + return Ok(None); + } + + let table = rec?; + + Ok(Some(table)) + } + + async fn get_by_namespace_and_name( + &mut self, + namespace_id: NamespaceId, + name: &str, + ) -> Result> { + let rec = sqlx::query_as::<_, Table>( + r#" +SELECT * +FROM table_name +WHERE namespace_id = $1 AND name = $2; + "#, + ) + .bind(namespace_id) // $1 + .bind(name) // $2 + .fetch_one(self.inner.get_mut()) + .await; + + if let Err(sqlx::Error::RowNotFound) = rec { + return Ok(None); + } + + let table = rec?; + + Ok(Some(table)) + } + + async fn list_by_namespace_id(&mut self, namespace_id: NamespaceId) -> Result> { + let rec = sqlx::query_as::<_, Table>( + r#" +SELECT * +FROM table_name +WHERE namespace_id = $1; + "#, + ) + .bind(namespace_id) + .fetch_all(self.inner.get_mut()) + .await?; + + Ok(rec) + } + + async fn list(&mut self) -> Result> { + let rec = sqlx::query_as::<_, Table>("SELECT * FROM table_name;") + .fetch_all(self.inner.get_mut()) + .await?; + + Ok(rec) + } + + async fn snapshot(&mut self, table_id: TableId) -> Result { + let mut tx = self.inner.get_mut().pool.begin().await?; + + // This will upgrade the transaction to be exclusive + let rec = sqlx::query( + "UPDATE table_name SET generation = generation + 1 where id = $1 RETURNING *;", + ) + .bind(table_id) // 
$1 + .fetch_one(&mut *tx) + .await; + + if let Err(sqlx::Error::RowNotFound) = rec { + return Err(Error::NotFound { + descr: format!("table: {table_id}"), + }); + } + let row = rec?; + + let generation: i64 = row.get("generation"); + let table = Table::from_row(&row)?; + + let columns = sqlx::query_as::<_, Column>("SELECT * from column_name where table_id = $1;") + .bind(table_id) // $1 + .fetch_all(&mut *tx) + .await?; + + let partitions = + sqlx::query_as::<_, PartitionPod>("SELECT * from partition where table_id = $1;") + .bind(table_id) // $1 + .fetch_all(&mut *tx) + .await?; + + tx.commit().await?; + + Ok(TableSnapshot::encode( + table, + partitions.into_iter().map(Into::into).collect(), + columns, + generation as _, + )?) + } +} + +#[async_trait] +impl ColumnRepo for SqliteTxn { + async fn create_or_get( + &mut self, + name: &str, + table_id: TableId, + column_type: ColumnType, + ) -> Result { + insert_column_with_connection(self.inner.get_mut(), name, table_id, column_type).await + } + + async fn list_by_namespace_id(&mut self, namespace_id: NamespaceId) -> Result> { + let rec = sqlx::query_as::<_, Column>( + r#" +SELECT column_name.* FROM table_name +INNER JOIN column_name on column_name.table_id = table_name.id +WHERE table_name.namespace_id = $1; + "#, + ) + .bind(namespace_id) + .fetch_all(self.inner.get_mut()) + .await?; + + Ok(rec) + } + + async fn list_by_table_id(&mut self, table_id: TableId) -> Result> { + let rec = sqlx::query_as::<_, Column>( + r#" +SELECT * FROM column_name +WHERE table_id = $1; + "#, + ) + .bind(table_id) + .fetch_all(self.inner.get_mut()) + .await?; + + Ok(rec) + } + + async fn list(&mut self) -> Result> { + let rec = sqlx::query_as::<_, Column>("SELECT * FROM column_name;") + .fetch_all(self.inner.get_mut()) + .await?; + + Ok(rec) + } + + async fn create_or_get_many_unchecked( + &mut self, + table_id: TableId, + columns: HashMap<&str, ColumnType>, + ) -> Result> { + let num_columns = columns.len(); + #[derive(Deserialize, 
Serialize)] + struct NameType<'a> { + name: &'a str, + column_type: i8, + } + impl<'a> NameType<'a> { + fn from(value: (&&'a str, &ColumnType)) -> Self { + Self { + name: value.0, + column_type: *value.1 as i8, + } + } + } + let cols = columns.iter().map(NameType::<'_>::from).collect::>(); + + // The `ORDER BY` in this statement is important to avoid deadlocks during concurrent + // writes to the same IOx table that each add many new columns. See: + // + // - + // - + // - + let out = sqlx::query_as::<_, Column>( + r#" +INSERT INTO column_name ( name, table_id, column_type ) +SELECT a.value ->> 'name' AS name, $1, a.value ->> 'column_type' AS column_type +FROM json_each($2) as a +ORDER BY name +ON CONFLICT (table_id, name) +DO UPDATE SET name = column_name.name +RETURNING *; + "#, + ) + .bind(table_id) // $1 + .bind(&Json(cols)) // $2 + .fetch_all(self.inner.get_mut()) + .await + .map_err(|e| { + if is_fk_violation(&e) { + Error::NotFound { + descr: e.to_string(), + } + } else { + Error::External { + source: Box::new(e), + } + } + })?; + + assert_eq!(num_columns, out.len()); + + for existing in &out { + let want = columns.get(existing.name.as_str()).unwrap(); + ensure!( + existing.column_type == *want, + AlreadyExistsSnafu { + descr: format!( + "column {} is type {} but schema update has type {}", + existing.name, existing.column_type, want + ), + } + ); + } + + Ok(out) + } +} + +// We can't use [`Partition`], as uses Vec which the Sqlite +// driver cannot serialise + +#[derive(Debug, Clone, PartialEq, Eq, sqlx::FromRow)] +struct PartitionPod { + id: PartitionId, + hash_id: Option, + table_id: TableId, + partition_key: PartitionKey, + sort_key_ids: Json>, + new_file_at: Option, +} + +impl From for Partition { + fn from(value: PartitionPod) -> Self { + let sort_key_ids = SortKeyIds::from(value.sort_key_ids.0); + + Self::new_catalog_only( + value.id, + value.hash_id, + value.table_id, + value.partition_key, + sort_key_ids, + value.new_file_at, + ) + } +} + 
+#[async_trait] +impl PartitionRepo for SqliteTxn { + async fn create_or_get(&mut self, key: PartitionKey, table_id: TableId) -> Result { + // Note: since sort_key is now an array, we must explicitly insert '{}' which is an empty + // array rather than NULL which sqlx will throw `UnexpectedNullError` while is is doing + // `ColumnDecode` + + let hash_id = PartitionHashId::new(table_id, &key); + + let v = sqlx::query_as::<_, PartitionPod>( + r#" +INSERT INTO partition + (partition_key, table_id, hash_id, sort_key_ids) +VALUES + ($1, $2, $3, '[]') +ON CONFLICT (table_id, partition_key) +DO UPDATE SET partition_key = partition.partition_key +RETURNING id, hash_id, table_id, partition_key, sort_key_ids, new_file_at; + "#, + ) + .bind(key) // $1 + .bind(table_id) // $2 + .bind(&hash_id) // $3 + .fetch_one(self.inner.get_mut()) + .await + .map_err(|e| { + if is_fk_violation(&e) { + Error::NotFound { + descr: e.to_string(), + } + } else { + Error::External { + source: Box::new(e), + } + } + })?; + + Ok(v.into()) + } + + async fn get_by_id_batch(&mut self, partition_ids: &[PartitionId]) -> Result> { + // We use a JSON-based "IS IN" check. + let ids: Vec<_> = partition_ids.iter().map(|p| p.get()).collect(); + + sqlx::query_as::<_, PartitionPod>( + r#" +SELECT id, hash_id, table_id, partition_key, sort_key_ids, new_file_at +FROM partition +WHERE id IN (SELECT value FROM json_each($1)); + "#, + ) + .bind(Json(&ids[..])) // $1 + .fetch_all(self.inner.get_mut()) + .await + .map(|vals| vals.into_iter().map(Partition::from).collect()) + .map_err(Error::from) + } + + async fn list_by_table_id(&mut self, table_id: TableId) -> Result> { + Ok(sqlx::query_as::<_, PartitionPod>( + r#" +SELECT id, hash_id, table_id, partition_key, sort_key_ids, new_file_at +FROM partition +WHERE table_id = $1; + "#, + ) + .bind(table_id) // $1 + .fetch_all(self.inner.get_mut()) + .await? 
+ .into_iter() + .map(Into::into) + .collect()) + } + + async fn list_ids(&mut self) -> Result> { + sqlx::query_as( + r#" + SELECT p.id as partition_id + FROM partition p + "#, + ) + .fetch_all(self.inner.get_mut()) + .await + .map_err(Error::from) + } + + /// Update the sort key for `partition_id` if and only if `old_sort_key` + /// matches the current value in the database. + /// + /// This compare-and-swap operation is allowed to spuriously return + /// [`CasFailure::ValueMismatch`] for performance reasons (avoiding multiple + /// round trips to service a transaction in the happy path). + async fn cas_sort_key( + &mut self, + partition_id: PartitionId, + old_sort_key_ids: Option<&SortKeyIds>, + new_sort_key_ids: &SortKeyIds, + ) -> Result> { + let old_sort_key_ids: Vec = old_sort_key_ids.map(Into::into).unwrap_or_default(); + + let raw_new_sort_key_ids: Vec = new_sort_key_ids.into(); + + // This `match` will go away when all partitions have hash IDs in the database. + let query = sqlx::query_as::<_, PartitionPod>( + r#" +UPDATE partition +SET sort_key_ids = $1 +WHERE id = $2 AND sort_key_ids = $3 +RETURNING id, hash_id, table_id, partition_key, sort_key_ids, new_file_at; + "#, + ) + .bind(Json(raw_new_sort_key_ids)) // $1 + .bind(partition_id) // $2 + .bind(Json(old_sort_key_ids)); // $3 + + let res = query.fetch_one(self.inner.get_mut()).await; + + let partition = match res { + Ok(v) => v, + Err(sqlx::Error::RowNotFound) => { + // This update may have failed either because: + // + // * A row with the specified ID did not exist at query time + // (but may exist now!) + // * The sort key does not match. + // + // To differentiate, we submit a get partition query, returning + // the actual sort key if successful. + // + // NOTE: this is racy, but documented - this might return "Sort + // key differs! Old key: " + + let partition = (self as &mut dyn PartitionRepo) + .get_by_id(partition_id) + .await + .map_err(CasFailure::QueryError)? 
+ .ok_or(CasFailure::QueryError(Error::NotFound { + descr: partition_id.to_string(), + }))?; + return Err(CasFailure::ValueMismatch( + partition.sort_key_ids().cloned().unwrap_or_default(), + )); + } + Err(e) => { + return Err(CasFailure::QueryError(Error::External { + source: Box::new(e), + })) + } + }; + + debug!(?partition_id, "partition sort key cas successful"); + + Ok(partition.into()) + } + + async fn record_skipped_compaction( + &mut self, + partition_id: PartitionId, + reason: &str, + num_files: usize, + limit_num_files: usize, + limit_num_files_first_in_partition: usize, + estimated_bytes: u64, + limit_bytes: u64, + ) -> Result<()> { + sqlx::query( + r#" +INSERT INTO skipped_compactions + ( partition_id, reason, num_files, limit_num_files, limit_num_files_first_in_partition, estimated_bytes, limit_bytes, skipped_at ) +VALUES + ( $1, $2, $3, $4, $5, $6, $7, $8 ) +ON CONFLICT ( partition_id ) +DO UPDATE +SET +reason = EXCLUDED.reason, +num_files = EXCLUDED.num_files, +limit_num_files = EXCLUDED.limit_num_files, +limit_num_files_first_in_partition = EXCLUDED.limit_num_files_first_in_partition, +estimated_bytes = EXCLUDED.estimated_bytes, +limit_bytes = EXCLUDED.limit_bytes, +skipped_at = EXCLUDED.skipped_at; + "#, + ) + .bind(partition_id) // $1 + .bind(reason) + .bind(num_files as i64) + .bind(limit_num_files as i64) + .bind(limit_num_files_first_in_partition as i64) + .bind(estimated_bytes as i64) + .bind(limit_bytes as i64) + .bind(std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs() as i64) + .execute(self.inner.get_mut()) + .await?; + Ok(()) + } + + async fn get_in_skipped_compactions( + &mut self, + partition_ids: &[PartitionId], + ) -> Result> { + let ids = partition_ids.iter().map(|p| p.get()).collect::>(); + let rec = sqlx::query_as::( + r#"SELECT * FROM skipped_compactions WHERE partition_id IN (SELECT value FROM json_each($1));"#, + ) + .bind(Json(&ids[..])) + .fetch_all(self.inner.get_mut()) + .await; + + let 
skipped_partition_records = rec?; + + Ok(skipped_partition_records) + } + + async fn list_skipped_compactions(&mut self) -> Result> { + sqlx::query_as::<_, SkippedCompaction>( + r#" +SELECT * FROM skipped_compactions + "#, + ) + .fetch_all(self.inner.get_mut()) + .await + .map_err(Error::from) + } + + async fn delete_skipped_compactions( + &mut self, + partition_id: PartitionId, + ) -> Result> { + sqlx::query_as::<_, SkippedCompaction>( + r#" +DELETE FROM skipped_compactions +WHERE partition_id = $1 +RETURNING * + "#, + ) + .bind(partition_id) + .fetch_optional(self.inner.get_mut()) + .await + .map_err(Error::from) + } + + async fn most_recent_n(&mut self, n: usize) -> Result> { + Ok(sqlx::query_as::<_, PartitionPod>( + r#" +SELECT id, hash_id, table_id, partition_key, sort_key_ids, new_file_at +FROM partition +ORDER BY id DESC +LIMIT $1; + "#, + ) + .bind(n as i64) // $1 + .fetch_all(self.inner.get_mut()) + .await? + .into_iter() + .map(Into::into) + .collect()) + } + + async fn partitions_new_file_between( + &mut self, + minimum_time: Timestamp, + maximum_time: Option, + ) -> Result> { + let sql = format!( + r#" + SELECT p.id as partition_id + FROM partition p + WHERE p.new_file_at > $1 + {} + "#, + maximum_time + .map(|_| "AND p.new_file_at < $2") + .unwrap_or_default() + ); + + sqlx::query_as(&sql) + .bind(minimum_time) // $1 + .bind(maximum_time) // $2 + .fetch_all(self.inner.get_mut()) + .await + .map_err(Error::from) + } + + async fn list_old_style(&mut self) -> Result> { + Ok(sqlx::query_as::<_, PartitionPod>( + r#" +SELECT id, hash_id, table_id, partition_key, sort_key_ids, new_file_at +FROM partition +WHERE hash_id IS NULL +ORDER BY id DESC; + "#, + ) + .fetch_all(self.inner.get_mut()) + .await? 
+ .into_iter() + .map(Into::into) + .collect()) + } + + async fn snapshot(&mut self, partition_id: PartitionId) -> Result { + let mut tx = self.inner.get_mut().pool.begin().await?; + + // This will upgrade the transaction to be exclusive + let rec = sqlx::query( + "UPDATE partition SET generation = generation + 1 where id = $1 RETURNING *;", + ) + .bind(partition_id) // $1 + .fetch_one(&mut *tx) + .await; + if let Err(sqlx::Error::RowNotFound) = rec { + return Err(Error::NotFound { + descr: format!("partition: {partition_id}"), + }); + } + let row = rec?; + + let generation: i64 = row.get("generation"); + let partition = PartitionPod::from_row(&row)?; + + let (namespace_id,): (NamespaceId,) = + sqlx::query_as("SELECT namespace_id from table_name where id = $1") + .bind(partition.table_id) // $1 + .fetch_one(&mut *tx) + .await?; + + let files = + sqlx::query_as::<_, ParquetFilePod>("SELECT * from parquet_file where partition_id = $1 AND parquet_file.to_delete IS NULL;") + .bind(partition_id) // $1 + .fetch_all(&mut *tx) + .await?; + + let sc = sqlx::query_as::( + r#"SELECT * FROM skipped_compactions WHERE partition_id = $1;"#, + ) + .bind(partition_id) + .fetch_optional(&mut *tx) + .await?; + + tx.commit().await?; + + Ok(PartitionSnapshot::encode( + namespace_id, + partition.into(), + files.into_iter().map(Into::into).collect(), + sc, + generation as _, + )?) 
+ } +} + +fn from_column_set(v: &ColumnSet) -> Json> { + Json((*v).iter().map(ColumnId::get).collect()) +} + +fn to_column_set(v: &Json>) -> ColumnSet { + ColumnSet::new(v.0.iter().map(|v| ColumnId::new(*v))) +} + +#[derive(Debug, Clone, PartialEq, Eq, sqlx::FromRow)] +struct ParquetFilePod { + id: ParquetFileId, + namespace_id: NamespaceId, + table_id: TableId, + partition_id: PartitionId, + partition_hash_id: Option, + object_store_id: ObjectStoreId, + min_time: Timestamp, + max_time: Timestamp, + to_delete: Option, + file_size_bytes: i64, + row_count: i64, + compaction_level: CompactionLevel, + created_at: Timestamp, + column_set: Json>, + max_l0_created_at: Timestamp, +} + +impl From for ParquetFile { + fn from(value: ParquetFilePod) -> Self { + Self { + id: value.id, + namespace_id: value.namespace_id, + table_id: value.table_id, + partition_id: value.partition_id, + partition_hash_id: value.partition_hash_id, + object_store_id: value.object_store_id, + min_time: value.min_time, + max_time: value.max_time, + to_delete: value.to_delete, + file_size_bytes: value.file_size_bytes, + row_count: value.row_count, + compaction_level: value.compaction_level, + created_at: value.created_at, + column_set: to_column_set(&value.column_set), + max_l0_created_at: value.max_l0_created_at, + } + } +} + +#[async_trait] +impl ParquetFileRepo for SqliteTxn { + async fn flag_for_delete_by_retention(&mut self) -> Result> { + let flagged_at = Timestamp::from(self.time_provider.now()); + // TODO - include check of table retention period once implemented + let flagged = sqlx::query( + r#" +WITH parquet_file_ids as ( + SELECT parquet_file.object_store_id + FROM namespace, parquet_file + WHERE namespace.retention_period_ns IS NOT NULL + AND parquet_file.to_delete IS NULL + AND parquet_file.max_time < $1 - namespace.retention_period_ns + AND namespace.id = parquet_file.namespace_id + LIMIT $2 +) +UPDATE parquet_file +SET to_delete = $1 +WHERE object_store_id IN (SELECT object_store_id 
FROM parquet_file_ids) +RETURNING partition_id, object_store_id; + "#, + ) + .bind(flagged_at) // $1 + .bind(MAX_PARQUET_FILES_SELECTED_ONCE_FOR_RETENTION) // $2 + .fetch_all(self.inner.get_mut()) + .await?; + + let flagged = flagged + .into_iter() + .map(|row| (row.get("partition_id"), row.get("object_store_id"))) + .collect(); + Ok(flagged) + } + + async fn delete_old_ids_only(&mut self, older_than: Timestamp) -> Result> { + // see https://www.crunchydata.com/blog/simulating-update-or-delete-with-limit-in-sqlite-ctes-to-the-rescue + let deleted = sqlx::query( + r#" +WITH parquet_file_ids as ( + SELECT object_store_id + FROM parquet_file + WHERE to_delete < $1 + LIMIT $2 +) +DELETE FROM parquet_file +WHERE object_store_id IN (SELECT object_store_id FROM parquet_file_ids) +RETURNING object_store_id; + "#, + ) + .bind(older_than) // $1 + .bind(MAX_PARQUET_FILES_SELECTED_ONCE_FOR_DELETE) // $2 + .fetch_all(self.inner.get_mut()) + .await?; + + let deleted = deleted + .into_iter() + .map(|row| row.get("object_store_id")) + .collect(); + Ok(deleted) + } + + async fn list_by_partition_not_to_delete_batch( + &mut self, + partition_ids: Vec, + ) -> Result> { + // We use a JSON-based "IS IN" check. + let ids: Vec<_> = partition_ids.iter().map(|p| p.get()).collect(); + + let query = sqlx::query_as::<_, ParquetFilePod>( + r#" +SELECT parquet_file.id, namespace_id, parquet_file.table_id, partition_id, partition_hash_id, + object_store_id, min_time, max_time, parquet_file.to_delete, file_size_bytes, row_count, + compaction_level, created_at, column_set, max_l0_created_at +FROM parquet_file +WHERE parquet_file.partition_id IN (SELECT value FROM json_each($1)) + AND parquet_file.to_delete IS NULL; + "#, + ) + .bind(Json(&ids[..])); // $1 + + Ok(query + .fetch_all(self.inner.get_mut()) + .await? 
+ .into_iter() + .map(Into::into) + .collect()) + } + + async fn get_by_object_store_id( + &mut self, + object_store_id: ObjectStoreId, + ) -> Result> { + let rec = sqlx::query_as::<_, ParquetFilePod>( + r#" +SELECT id, namespace_id, table_id, partition_id, partition_hash_id, object_store_id, min_time, + max_time, to_delete, file_size_bytes, row_count, compaction_level, created_at, column_set, + max_l0_created_at +FROM parquet_file +WHERE object_store_id = $1; + "#, + ) + .bind(object_store_id) // $1 + .fetch_one(self.inner.get_mut()) + .await; + + if let Err(sqlx::Error::RowNotFound) = rec { + return Ok(None); + } + + let parquet_file = rec?; + + Ok(Some(parquet_file.into())) + } + + async fn exists_by_object_store_id_batch( + &mut self, + object_store_ids: Vec, + ) -> Result> { + let in_value = object_store_ids + .into_iter() + // use a sqlite blob literal + .map(|id| format!("X'{}'", id.get_uuid().simple())) + .collect::>() + .join(","); + + sqlx::query(&format!( + " +SELECT object_store_id +FROM parquet_file +WHERE object_store_id IN ({v});", + v = in_value + )) + .map(|slr: SqliteRow| slr.get::("object_store_id")) + // limitation of sqlx: will not bind arrays + // https://github.com/launchbadge/sqlx/blob/main/FAQ.md#how-can-i-do-a-select--where-foo-in--query + .fetch_all(self.inner.get_mut()) + .await + .map_err(Error::from) + } + + async fn create_upgrade_delete( + &mut self, + partition_id: PartitionId, + delete: &[ObjectStoreId], + upgrade: &[ObjectStoreId], + create: &[ParquetFileParams], + target_level: CompactionLevel, + ) -> Result> { + let delete_set = delete.iter().copied().collect::>(); + let upgrade_set = upgrade.iter().copied().collect::>(); + + assert!( + delete_set.is_disjoint(&upgrade_set), + "attempted to upgrade a file scheduled for delete" + ); + let mut tx = self.inner.get_mut().pool.begin().await?; + + for id in delete { + let marked_at = Timestamp::from(self.time_provider.now()); + flag_for_delete(&mut *tx, partition_id, *id, 
marked_at).await?; + } + + update_compaction_level(&mut *tx, partition_id, upgrade, target_level).await?; + + let mut ids = Vec::with_capacity(create.len()); + for file in create { + if file.partition_id != partition_id { + return Err(Error::External { + source: format!("Inconsistent ParquetFileParams, expected PartitionId({partition_id}) got PartitionId({})", file.partition_id).into(), + }); + } + let res = create_parquet_file(&mut *tx, file.clone()).await?; + ids.push(res.id); + } + tx.commit().await?; + + Ok(ids) + } +} + +// The following three functions are helpers to the create_upgrade_delete method. +// They are also used by the respective create/flag_for_delete/update_compaction_level methods. +async fn create_parquet_file<'q, E>( + executor: E, + parquet_file_params: ParquetFileParams, +) -> Result +where + E: Executor<'q, Database = Sqlite>, +{ + let ParquetFileParams { + namespace_id, + table_id, + partition_id, + partition_hash_id, + object_store_id, + min_time, + max_time, + file_size_bytes, + row_count, + compaction_level, + created_at, + column_set, + max_l0_created_at, + } = parquet_file_params; + + let res = sqlx::query_as::<_, ParquetFilePod>( + r#" +INSERT INTO parquet_file ( + table_id, partition_id, partition_hash_id, object_store_id, + min_time, max_time, file_size_bytes, + row_count, compaction_level, created_at, namespace_id, column_set, max_l0_created_at ) +VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13 ) +RETURNING + id, table_id, partition_id, partition_hash_id, object_store_id, min_time, max_time, to_delete, + file_size_bytes, row_count, compaction_level, created_at, namespace_id, column_set, + max_l0_created_at; + "#, + ) + .bind(table_id) // $1 + .bind(partition_id) // $2 + .bind(partition_hash_id.as_ref()) // $3 + .bind(object_store_id) // $4 + .bind(min_time) // $5 + .bind(max_time) // $6 + .bind(file_size_bytes) // $7 + .bind(row_count) // $8 + .bind(compaction_level) // $9 + .bind(created_at) // $10 + 
.bind(namespace_id) // $11 + .bind(from_column_set(&column_set)) // $12 + .bind(max_l0_created_at) // $13 + .fetch_one(executor) + .await; + + let rec = res.map_err(|e| { + if is_unique_violation(&e) { + Error::AlreadyExists { + descr: object_store_id.to_string(), + } + } else if is_fk_violation(&e) { + Error::NotFound { + descr: e.to_string(), + } + } else { + Error::External { + source: Box::new(e), + } + } + })?; + + Ok(rec.into()) +} + +async fn flag_for_delete<'q, E>( + executor: E, + partition_id: PartitionId, + id: ObjectStoreId, + marked_at: Timestamp, +) -> Result<()> +where + E: Executor<'q, Database = Sqlite>, +{ + let updated = + sqlx::query_as::<_, (i64,)>(r#"UPDATE parquet_file SET to_delete = $1 WHERE object_store_id = $2 AND partition_id = $3 AND to_delete is NULL returning id;"#) + .bind(marked_at) // $1 + .bind(id) // $2 + .bind(partition_id) // $3 + .fetch_all(executor) + .await?; + + if updated.len() != 1 { + return Err(Error::NotFound { + descr: format!("parquet file {id} not found for delete"), + }); + } + + Ok(()) +} + +async fn update_compaction_level<'q, E>( + executor: E, + partition_id: PartitionId, + object_store_ids: &[ObjectStoreId], + compaction_level: CompactionLevel, +) -> Result<()> +where + E: Executor<'q, Database = Sqlite>, +{ + let in_value = object_store_ids + .iter() + // use a sqlite blob literal + .map(|id| format!("X'{}'", id.get_uuid().simple())) + .collect::>() + .join(","); + + let updated = sqlx::query_as::<_, (i64,)>(&format!( + r#" +UPDATE parquet_file +SET compaction_level = $1 +WHERE object_store_id IN ({v}) AND partition_id = $2 AND to_delete is NULL returning id; + "#, + v = in_value, + )) + .bind(compaction_level) // $1 + .bind(partition_id) // $2 + .fetch_all(executor) + .await?; + + if updated.len() != object_store_ids.len() { + return Err(Error::NotFound { + descr: "parquet file(s) not found for upgrade".to_string(), + }); + } + + Ok(()) +} + +/// The error code returned by SQLite for a unique constraint 
violation. +/// +/// See +const SQLITE_UNIQUE_VIOLATION: &str = "2067"; + +/// Error code returned by SQLite for a foreign key constraint violation. +/// See +const SQLITE_FK_VIOLATION: &str = "787"; + +fn is_fk_violation(e: &sqlx::Error) -> bool { + if let sqlx::Error::Database(inner) = e { + if let Some(code) = inner.code() { + if code == SQLITE_FK_VIOLATION { + return true; + } + } + } + + false +} + +/// Returns true if `e` is a unique constraint violation error. +fn is_unique_violation(e: &sqlx::Error) -> bool { + if let sqlx::Error::Database(inner) = e { + if let Some(code) = inner.code() { + if code == SQLITE_UNIQUE_VIOLATION { + return true; + } + } + } + + false +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::interface::ParquetFileRepoExt; + use crate::test_helpers::{ + arbitrary_namespace, arbitrary_parquet_file_params, arbitrary_table, + }; + use assert_matches::assert_matches; + use data_types::partition_template::TemplatePart; + use generated_types::influxdata::iox::partition_template::v1 as proto; + use std::sync::Arc; + + async fn setup_db() -> SqliteCatalog { + let dsn = + std::env::var("TEST_INFLUXDB_SQLITE_DSN").unwrap_or("sqlite::memory:".to_string()); + let options = SqliteConnectionOptions { file_path: dsn }; + let metrics = Arc::new(Registry::default()); + let cat = SqliteCatalog::connect(options, metrics) + .await + .expect("failed to connect to catalog"); + cat.setup().await.expect("failed to initialise database"); + cat + } + + #[tokio::test] + async fn test_catalog() { + crate::interface_tests::test_catalog(|| async { + let sqlite = setup_db().await; + let sqlite: Arc = Arc::new(sqlite); + sqlite + }) + .await; + } + + #[tokio::test] + async fn existing_partitions_without_hash_id() { + let sqlite: SqliteCatalog = setup_db().await; + let pool = sqlite.pool.clone(); + let sqlite: Arc = Arc::new(sqlite); + let mut repos = sqlite.repositories(); + + let namespace = arbitrary_namespace(&mut *repos, "ns4").await; + let table = 
arbitrary_table(&mut *repos, "table", &namespace).await; + let table_id = table.id; + let key = PartitionKey::from("francis-scott-key-key"); + + // Create a partition record in the database that has `NULL` for its `hash_id` + // value, which is what records existing before the migration adding that column will have. + sqlx::query( + r#" +INSERT INTO partition + (partition_key, table_id, sort_key_ids) +VALUES + ($1, $2, '[]') +ON CONFLICT (table_id, partition_key) +DO UPDATE SET partition_key = partition.partition_key +RETURNING id, hash_id, table_id, partition_key, sort_key_ids, new_file_at; + "#, + ) + .bind(&key) // $1 + .bind(table_id) // $2 + .fetch_one(&pool) + .await + .unwrap(); + + // Check that the hash_id being null in the database doesn't break querying for partitions. + let table_partitions = repos.partitions().list_by_table_id(table_id).await.unwrap(); + assert_eq!(table_partitions.len(), 1); + let partition = &table_partitions[0]; + + // Call create_or_get for the same (key, table_id) pair, to ensure the write is idempotent + // and that the hash_id still doesn't get set. 
+ let inserted_again = repos + .partitions() + .create_or_get(key, table_id) + .await + .expect("idempotent write should succeed"); + + // Test: sort_key_ids from freshly insert with empty value + assert!(inserted_again.sort_key_ids().is_none()); + + assert_eq!(partition, &inserted_again); + + // Create a Parquet file record in this partition to ensure we don't break new data + // ingestion for old-style partitions + let parquet_file_params = arbitrary_parquet_file_params(&namespace, &table, partition); + let parquet_file = repos + .parquet_files() + .create(parquet_file_params) + .await + .unwrap(); + assert_eq!(parquet_file.partition_hash_id, None); + + // Add a partition record WITH a hash ID + repos + .partitions() + .create_or_get(PartitionKey::from("Something else"), table_id) + .await + .unwrap(); + + // Ensure we can list only the old-style partitions + let old_style_partitions = repos.partitions().list_old_style().await.unwrap(); + assert_eq!(old_style_partitions.len(), 1); + assert_eq!(old_style_partitions[0].id, partition.id); + } + + #[tokio::test] + async fn test_billing_summary_on_parqet_file_creation() { + let sqlite = setup_db().await; + let pool = sqlite.pool.clone(); + let sqlite: Arc = Arc::new(sqlite); + let mut repos = sqlite.repositories(); + let namespace = arbitrary_namespace(&mut *repos, "ns4").await; + let table = arbitrary_table(&mut *repos, "table", &namespace).await; + let key = "bananas"; + let partition = repos + .partitions() + .create_or_get(key.into(), table.id) + .await + .unwrap(); + + // parquet file to create- all we care about here is the size + let mut p1 = arbitrary_parquet_file_params(&namespace, &table, &partition); + p1.file_size_bytes = 1337; + let f1 = repos + .parquet_files() + .create(p1.clone()) + .await + .expect("create parquet file should succeed"); + // insert the same again with a different size; we should then have 3x1337 as total file + // size + p1.object_store_id = ObjectStoreId::new(); + p1.file_size_bytes 
*= 2; + let _f2 = repos + .parquet_files() + .create(p1.clone()) + .await + .expect("create parquet file should succeed"); + + // after adding two files we should have 3x1337 in the summary + let total_file_size_bytes: i64 = + sqlx::query_scalar("SELECT total_file_size_bytes FROM billing_summary;") + .fetch_one(&pool) + .await + .expect("fetch total file size failed"); + assert_eq!(total_file_size_bytes, 1337 * 3); + + // flag f1 for deletion and assert that the total file size is reduced accordingly. + repos + .parquet_files() + .create_upgrade_delete( + partition.id, + &[f1.object_store_id], + &[], + &[], + CompactionLevel::Initial, + ) + .await + .expect("flag parquet file for deletion should succeed"); + let total_file_size_bytes: i64 = + sqlx::query_scalar("SELECT total_file_size_bytes FROM billing_summary;") + .fetch_one(&pool) + .await + .expect("fetch total file size failed"); + // we marked the first file of size 1337 for deletion leaving only the second that was 2x that + assert_eq!(total_file_size_bytes, 1337 * 2); + + // actually deleting shouldn't change the total + let older_than = p1.created_at + 1; + repos + .parquet_files() + .delete_old_ids_only(older_than) + .await + .expect("parquet file deletion should succeed"); + let total_file_size_bytes: i64 = + sqlx::query_scalar("SELECT total_file_size_bytes FROM billing_summary;") + .fetch_one(&pool) + .await + .expect("fetch total file size failed"); + assert_eq!(total_file_size_bytes, 1337 * 2); + } + + #[tokio::test] + async fn namespace_partition_template_null_is_the_default_in_the_database() { + let sqlite = setup_db().await; + let pool = sqlite.pool.clone(); + let sqlite: Arc = Arc::new(sqlite); + let mut repos = sqlite.repositories(); + + let namespace_name = "apples"; + + // Create a namespace record in the database that has `NULL` for its `partition_template` + // value, which is what records existing before the migration adding that column will have. 
+ let insert_null_partition_template_namespace = sqlx::query( + r#" +INSERT INTO namespace ( + name, retention_period_ns, partition_template +) +VALUES ( $1, $2, NULL ) +RETURNING id, name, retention_period_ns, max_tables, max_columns_per_table, deleted_at, + partition_template; + "#, + ) + .bind(namespace_name) // $1 + .bind(None::>); // $2 + + insert_null_partition_template_namespace + .fetch_one(&pool) + .await + .unwrap(); + + let lookup_namespace = repos + .namespaces() + .get_by_name(namespace_name, SoftDeletedRows::ExcludeDeleted) + .await + .unwrap() + .unwrap(); + // When fetching this namespace from the database, the `FromRow` impl should set its + // `partition_template` to the default. + assert_eq!( + lookup_namespace.partition_template, + NamespacePartitionTemplateOverride::default() + ); + + // When creating a namespace through the catalog functions without specifying a custom + // partition template, + let created_without_custom_template = repos + .namespaces() + .create( + &"lemons".try_into().unwrap(), + None, // no partition template + None, + None, + ) + .await + .unwrap(); + + // it should have the default template in the application, + assert_eq!( + created_without_custom_template.partition_template, + NamespacePartitionTemplateOverride::default() + ); + + // and store NULL in the database record. + let record = sqlx::query("SELECT name, partition_template FROM namespace WHERE id = $1;") + .bind(created_without_custom_template.id) + .fetch_one(&pool) + .await + .unwrap(); + let name: String = record.try_get("name").unwrap(); + assert_eq!(created_without_custom_template.name, name); + let partition_template: Option = + record.try_get("partition_template").unwrap(); + assert!(partition_template.is_none()); + + // When explicitly setting a template that happens to be equal to the application default, + // assume it's important that it's being specially requested and store it rather than NULL. 
+ let namespace_custom_template_name = "kumquats"; + let custom_partition_template_equal_to_default = + NamespacePartitionTemplateOverride::try_from(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat( + "%Y-%m-%d".to_owned(), + )), + }], + }) + .unwrap(); + let namespace_custom_template = repos + .namespaces() + .create( + &namespace_custom_template_name.try_into().unwrap(), + Some(custom_partition_template_equal_to_default.clone()), + None, + None, + ) + .await + .unwrap(); + assert_eq!( + namespace_custom_template.partition_template, + custom_partition_template_equal_to_default + ); + let record = sqlx::query("SELECT name, partition_template FROM namespace WHERE id = $1;") + .bind(namespace_custom_template.id) + .fetch_one(&pool) + .await + .unwrap(); + let name: String = record.try_get("name").unwrap(); + assert_eq!(namespace_custom_template.name, name); + let partition_template: Option = + record.try_get("partition_template").unwrap(); + assert_eq!( + partition_template.unwrap(), + custom_partition_template_equal_to_default + ); + } + + #[tokio::test] + async fn table_partition_template_null_is_the_default_in_the_database() { + let sqlite = setup_db().await; + let pool = sqlite.pool.clone(); + let sqlite: Arc = Arc::new(sqlite); + let mut repos = sqlite.repositories(); + + let namespace_default_template_name = "oranges"; + let namespace_default_template = repos + .namespaces() + .create( + &namespace_default_template_name.try_into().unwrap(), + None, // no partition template + None, + None, + ) + .await + .unwrap(); + + let namespace_custom_template_name = "limes"; + let namespace_custom_template = repos + .namespaces() + .create( + &namespace_custom_template_name.try_into().unwrap(), + Some( + NamespacePartitionTemplateOverride::try_from(proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TimeFormat("year-%Y".into())), + }], + }) + .unwrap(), + 
), + None, + None, + ) + .await + .unwrap(); + + // In a namespace that also has a NULL template, create a table record in the database that + // has `NULL` for its `partition_template` value, which is what records existing before the + // migration adding that column will have. + let table_name = "null_template"; + let insert_null_partition_template_table = sqlx::query( + r#" +INSERT INTO table_name ( name, namespace_id, partition_template ) +VALUES ( $1, $2, NULL ) +RETURNING *; + "#, + ) + .bind(table_name) // $1 + .bind(namespace_default_template.id); // $2 + + insert_null_partition_template_table + .fetch_one(&pool) + .await + .unwrap(); + + let lookup_table = repos + .tables() + .get_by_namespace_and_name(namespace_default_template.id, table_name) + .await + .unwrap() + .unwrap(); + // When fetching this table from the database, the `FromRow` impl should set its + // `partition_template` to the system default (because the namespace didn't have a template + // either). + assert_eq!( + lookup_table.partition_template, + TablePartitionTemplateOverride::default() + ); + + // In a namespace that has a custom template, create a table record in the database that + // has `NULL` for its `partition_template` value. + // + // THIS ACTUALLY SHOULD BE IMPOSSIBLE because: + // + // * Namespaces have to exist before tables + // * `partition_tables` are immutable on both namespaces and tables + // * When the migration adding the `partition_table` column is deployed, namespaces can + // begin to be created with `partition_templates` + // * *Then* tables can be created with `partition_templates` or not + // * When tables don't get a custom table partition template but their namespace has one, + // their database record will get the namespace partition template. + // + // In other words, table `partition_template` values in the database is allowed to possibly + // be `NULL` IFF their namespace's `partition_template` is `NULL`. 
+ // + // That said, this test creates this hopefully-impossible scenario to ensure that the + // defined, expected behavior if a table record somehow exists in the database with a `NULL` + // `partition_template` value is that it will have the application default partition + // template *even if the namespace `partition_template` is not null*. + let table_name = "null_template"; + let insert_null_partition_template_table = sqlx::query( + r#" +INSERT INTO table_name ( name, namespace_id, partition_template ) +VALUES ( $1, $2, NULL ) +RETURNING *; + "#, + ) + .bind(table_name) // $1 + .bind(namespace_custom_template.id); // $2 + + insert_null_partition_template_table + .fetch_one(&pool) + .await + .unwrap(); + + let lookup_table = repos + .tables() + .get_by_namespace_and_name(namespace_custom_template.id, table_name) + .await + .unwrap() + .unwrap(); + // When fetching this table from the database, the `FromRow` impl should set its + // `partition_template` to the system default *even though the namespace has a + // template*, because this should be impossible as detailed above. + assert_eq!( + lookup_table.partition_template, + TablePartitionTemplateOverride::default() + ); + + // # Table template false, namespace template true + // + // When creating a table through the catalog functions *without* a custom table template in + // a namespace *with* a custom partition template, + let table_no_template_with_namespace_template = repos + .tables() + .create( + "pomelo", + TablePartitionTemplateOverride::try_new( + None, // no custom partition template + &namespace_custom_template.partition_template, + ) + .unwrap(), + namespace_custom_template.id, + ) + .await + .unwrap(); + + // it should have the namespace's template + assert_eq!( + table_no_template_with_namespace_template.partition_template, + TablePartitionTemplateOverride::try_new( + None, + &namespace_custom_template.partition_template + ) + .unwrap() + ); + + // and store that value in the database record. 
+ let record = sqlx::query("SELECT name, partition_template FROM table_name WHERE id = $1;") + .bind(table_no_template_with_namespace_template.id) + .fetch_one(&pool) + .await + .unwrap(); + let name: String = record.try_get("name").unwrap(); + assert_eq!(table_no_template_with_namespace_template.name, name); + let partition_template: Option = + record.try_get("partition_template").unwrap(); + assert_eq!( + partition_template.unwrap(), + TablePartitionTemplateOverride::try_new( + None, + &namespace_custom_template.partition_template + ) + .unwrap() + ); + + // # Table template true, namespace template false + // + // When creating a table through the catalog functions *with* a custom table template in + // a namespace *without* a custom partition template, + let custom_table_template = proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue("chemical".into())), + }], + }; + let table_with_template_no_namespace_template = repos + .tables() + .create( + "tangerine", + TablePartitionTemplateOverride::try_new( + Some(custom_table_template), // with custom partition template + &namespace_default_template.partition_template, + ) + .unwrap(), + namespace_default_template.id, + ) + .await + .unwrap(); + + // it should have the custom table template + let table_template_parts: Vec<_> = table_with_template_no_namespace_template + .partition_template + .parts() + .collect(); + assert_eq!(table_template_parts.len(), 1); + assert_matches!( + table_template_parts[0], + TemplatePart::TagValue(tag) if tag == "chemical" + ); + + // and store that value in the database record. 
+ let record = sqlx::query("SELECT name, partition_template FROM table_name WHERE id = $1;") + .bind(table_with_template_no_namespace_template.id) + .fetch_one(&pool) + .await + .unwrap(); + let name: String = record.try_get("name").unwrap(); + assert_eq!(table_with_template_no_namespace_template.name, name); + let partition_template = record + .try_get::, _>("partition_template") + .unwrap() + .unwrap(); + let table_template_parts: Vec<_> = partition_template.parts().collect(); + assert_eq!(table_template_parts.len(), 1); + assert_matches!( + table_template_parts[0], + TemplatePart::TagValue(tag) if tag == "chemical" + ); + + // # Table template true, namespace template true + // + // When creating a table through the catalog functions *with* a custom table template in + // a namespace *with* a custom partition template, + let custom_table_template = proto::PartitionTemplate { + parts: vec![proto::TemplatePart { + part: Some(proto::template_part::Part::TagValue("vegetable".into())), + }], + }; + let table_with_template_with_namespace_template = repos + .tables() + .create( + "nectarine", + TablePartitionTemplateOverride::try_new( + Some(custom_table_template), // with custom partition template + &namespace_custom_template.partition_template, + ) + .unwrap(), + namespace_custom_template.id, + ) + .await + .unwrap(); + + // it should have the custom table template + let table_template_parts: Vec<_> = table_with_template_with_namespace_template + .partition_template + .parts() + .collect(); + assert_eq!(table_template_parts.len(), 1); + assert_matches!( + table_template_parts[0], + TemplatePart::TagValue(tag) if tag == "vegetable" + ); + + // and store that value in the database record. 
+ let record = sqlx::query("SELECT name, partition_template FROM table_name WHERE id = $1;") + .bind(table_with_template_with_namespace_template.id) + .fetch_one(&pool) + .await + .unwrap(); + let name: String = record.try_get("name").unwrap(); + assert_eq!(table_with_template_with_namespace_template.name, name); + let partition_template = record + .try_get::, _>("partition_template") + .unwrap() + .unwrap(); + let table_template_parts: Vec<_> = partition_template.parts().collect(); + assert_eq!(table_template_parts.len(), 1); + assert_matches!( + table_template_parts[0], + TemplatePart::TagValue(tag) if tag == "vegetable" + ); + + // # Table template false, namespace template false + // + // When creating a table through the catalog functions *without* a custom table template in + // a namespace *without* a custom partition template, + let table_no_template_no_namespace_template = repos + .tables() + .create( + "grapefruit", + TablePartitionTemplateOverride::try_new( + None, // no custom partition template + &namespace_default_template.partition_template, + ) + .unwrap(), + namespace_default_template.id, + ) + .await + .unwrap(); + + // it should have the default template in the application, + assert_eq!( + table_no_template_no_namespace_template.partition_template, + TablePartitionTemplateOverride::default() + ); + + // and store NULL in the database record. 
+ let record = sqlx::query("SELECT name, partition_template FROM table_name WHERE id = $1;") + .bind(table_no_template_no_namespace_template.id) + .fetch_one(&pool) + .await + .unwrap(); + let name: String = record.try_get("name").unwrap(); + assert_eq!(table_no_template_no_namespace_template.name, name); + let partition_template: Option = + record.try_get("partition_template").unwrap(); + assert!(partition_template.is_none()); + } +} diff --git a/iox_catalog/src/test_helpers.rs b/iox_catalog/src/test_helpers.rs new file mode 100644 index 0000000..0861d79 --- /dev/null +++ b/iox_catalog/src/test_helpers.rs @@ -0,0 +1,92 @@ +//! Catalog helper functions for creation of catalog objects +use data_types::{ + partition_template::TablePartitionTemplateOverride, ColumnId, ColumnSet, CompactionLevel, + Namespace, NamespaceName, ObjectStoreId, ParquetFileParams, Partition, Table, TableSchema, + Timestamp, +}; + +use crate::interface::RepoCollection; + +/// When the details of the namespace don't matter; the test just needs *a* catalog namespace +/// with a particular name. +/// +/// Use [`NamespaceRepo::create`] directly if: +/// +/// - The values of the parameters to `create` need to be different than what's here +/// - The values of the parameters to `create` are relevant to the behavior under test +/// - You expect namespace creation to fail in the test +/// +/// [`NamespaceRepo::create`]: crate::interface::NamespaceRepo::create +pub async fn arbitrary_namespace( + repos: &mut R, + name: &str, +) -> Namespace { + let namespace_name = NamespaceName::new(name).unwrap(); + repos + .namespaces() + .create(&namespace_name, None, None, None) + .await + .unwrap() +} + +/// When the details of the table don't matter; the test just needs *a* catalog table +/// with a particular name in a particular namespace. 
+/// +/// Use [`TableRepo::create`] directly if: +/// +/// - The values of the parameters to `create_or_get` need to be different than what's here +/// - The values of the parameters to `create_or_get` are relevant to the behavior under test +/// - You expect table creation to fail in the test +/// +/// [`TableRepo::create`]: crate::interface::TableRepo::create +pub async fn arbitrary_table( + repos: &mut R, + name: &str, + namespace: &Namespace, +) -> Table { + repos + .tables() + .create( + name, + TablePartitionTemplateOverride::try_new(None, &namespace.partition_template).unwrap(), + namespace.id, + ) + .await + .unwrap() +} + +/// Load or create an arbitrary table schema in the same way that a write implicitly creates a +/// table, that is, with a time column. +pub async fn arbitrary_table_schema_load_or_create( + repos: &mut R, + name: &str, + namespace: &Namespace, +) -> TableSchema { + crate::util::table_load_or_create(repos, namespace.id, &namespace.partition_template, name) + .await + .unwrap() +} + +/// When the details of a Parquet file record don't matter, the test just needs *a* Parquet +/// file record in a particular namespace+table+partition. +pub fn arbitrary_parquet_file_params( + namespace: &Namespace, + table: &Table, + partition: &Partition, +) -> ParquetFileParams { + ParquetFileParams { + namespace_id: namespace.id, + table_id: table.id, + partition_id: partition.id, + partition_hash_id: partition.hash_id().cloned(), + object_store_id: ObjectStoreId::new(), + min_time: Timestamp::new(1), + max_time: Timestamp::new(10), + file_size_bytes: 1337, + row_count: 0, + compaction_level: CompactionLevel::Initial, + created_at: Timestamp::new(1), + column_set: ColumnSet::new([ColumnId::new(1), ColumnId::new(2)]), + max_l0_created_at: Timestamp::new(1), + } +} diff --git a/iox_catalog/src/util.rs b/iox_catalog/src/util.rs new file mode 100644 index 0000000..d6d184f --- /dev/null +++ b/iox_catalog/src/util.rs @@ -0,0 +1,897 @@ +//! 
Helper methods to simplify catalog work. +//! +//! They all use the public [`Catalog`] interface and have no special access to internals, so in theory they can be +//! implement downstream as well. + +use std::{ + borrow::Cow, + collections::{BTreeMap, HashMap, HashSet}, + sync::Arc, +}; + +use data_types::{ + partition_template::{NamespacePartitionTemplateOverride, TablePartitionTemplateOverride}, + ColumnType, ColumnsByName, Namespace, NamespaceId, NamespaceSchema, PartitionId, SortKeyIds, + TableId, TableSchema, +}; +use mutable_batch::MutableBatch; +use thiserror::Error; + +use crate::{ + constants::TIME_COLUMN, + interface::{CasFailure, Catalog, Error, RepoCollection, SoftDeletedRows}, +}; + +/// Gets the namespace schema including all tables and columns. +pub async fn get_schema_by_id( + id: NamespaceId, + repos: &mut R, + deleted: SoftDeletedRows, +) -> Result, crate::interface::Error> +where + R: RepoCollection + ?Sized, +{ + let Some(namespace) = repos.namespaces().get_by_id(id, deleted).await? else { + return Ok(None); + }; + + Ok(Some(get_schema_internal(namespace, repos).await?)) +} + +/// Gets the namespace schema including all tables and columns. +pub async fn get_schema_by_name( + name: &str, + repos: &mut R, + deleted: SoftDeletedRows, +) -> Result, crate::interface::Error> +where + R: RepoCollection + ?Sized, +{ + let Some(namespace) = repos.namespaces().get_by_name(name, deleted).await? else { + return Ok(None); + }; + + Ok(Some(get_schema_internal(namespace, repos).await?)) +} + +async fn get_schema_internal( + namespace: Namespace, + repos: &mut R, +) -> Result +where + R: RepoCollection + ?Sized, +{ + // get the columns first just in case someone else is creating schema while we're doing this. 
+ let columns = repos.columns().list_by_namespace_id(namespace.id).await?; + let tables = repos.tables().list_by_namespace_id(namespace.id).await?; + + let mut namespace = NamespaceSchema::new_empty_from(&namespace); + + let mut table_id_to_schema = BTreeMap::new(); + for t in tables { + let table_schema = TableSchema::new_empty_from(&t); + table_id_to_schema.insert(t.id, (t.name, table_schema)); + } + + for c in columns { + let (_, t) = table_id_to_schema.get_mut(&c.table_id).unwrap(); + t.add_column(c); + } + + for (_, (table_name, schema)) in table_id_to_schema { + namespace.tables.insert(table_name, schema); + } + + Ok(namespace) +} + +/// Gets the schema for one particular table in a namespace. +pub async fn get_schema_by_namespace_and_table( + name: &str, + table_name: &str, + repos: &mut R, + deleted: SoftDeletedRows, +) -> Result, crate::interface::Error> +where + R: RepoCollection + ?Sized, +{ + let Some(namespace) = repos.namespaces().get_by_name(name, deleted).await? else { + return Ok(None); + }; + + let Some(table) = repos + .tables() + .get_by_namespace_and_name(namespace.id, table_name) + .await? + else { + return Ok(None); + }; + + let mut table_schema = TableSchema::new_empty_from(&table); + + let columns = repos.columns().list_by_table_id(table.id).await?; + for c in columns { + table_schema.add_column(c); + } + + let mut namespace = NamespaceSchema::new_empty_from(&namespace); + namespace + .tables + .insert(table_name.to_string(), table_schema); + + Ok(Some(namespace)) +} + +/// Gets all the table's columns. +pub async fn get_table_columns_by_id( + id: TableId, + repos: &mut R, +) -> Result +where + R: RepoCollection + ?Sized, +{ + let columns = repos.columns().list_by_table_id(id).await?; + + Ok(ColumnsByName::new(columns)) +} + +/// Fetch all [`NamespaceSchema`] in the catalog. +/// +/// This method performs the minimal number of queries needed to build the +/// result set. 
No table lock is obtained, nor are queries executed within a +/// transaction, but this method does return a point-in-time snapshot of the +/// catalog state. +/// +/// # Soft Deletion +/// +/// No schemas for soft-deleted namespaces are returned. +pub async fn list_schemas( + catalog: &dyn Catalog, +) -> Result, crate::interface::Error> { + let mut repos = catalog.repositories(); + + // In order to obtain a point-in-time snapshot, first fetch the columns, + // then the tables, and then resolve the namespace IDs to Namespace in order + // to construct the schemas. + // + // The set of columns returned forms the state snapshot, with the subsequent + // queries resolving only what is needed to construct schemas for the + // retrieved columns (ignoring any newly added tables/namespaces since the + // column snapshot was taken). + // + // This approach also tolerates concurrently deleted namespaces, which are + // simply ignored at the end when joining to the namespace query result. + + // First fetch all the columns - this is the state snapshot of the catalog + // schemas. + let columns = repos.columns().list().await?; + + // Construct the set of table IDs these columns belong to. + let retain_table_ids = columns.iter().map(|c| c.table_id).collect::>(); + + // Fetch all tables, and filter for those that are needed to construct + // schemas for "columns" only. + // + // Discard any tables that have no columns or have been created since + // the "columns" snapshot was retrieved, and construct a map of ID->Table. + let tables = repos + .tables() + .list() + .await? + .into_iter() + .filter_map(|t| { + if !retain_table_ids.contains(&t.id) { + return None; + } + + Some((t.id, t)) + }) + .collect::>(); + + // Drop the table ID set as it will not be referenced again. + drop(retain_table_ids); + + // Do all the I/O to fetch the namespaces in the background, while this + // thread constructs the NamespaceId->TableSchema map below. 
+ let namespaces = tokio::spawn(async move { + repos + .namespaces() + .list(SoftDeletedRows::ExcludeDeleted) + .await + }); + + // A set of tables within a single namespace. + type NamespaceTables = BTreeMap; + + let mut joined = HashMap::::default(); + for column in columns { + // Resolve the table this column references + let table = tables.get(&column.table_id).expect("no table for column"); + + let table_schema = joined + // Find or create a record in the joined map + // for this namespace ID. + .entry(table.namespace_id) + .or_default() + // Fetch the schema record for this table, or create an empty one. + .entry(table.name.clone()) + .or_insert_with(|| TableSchema::new_empty_from(table)); + + table_schema.add_column(column); + } + + // The table map is no longer needed - immediately reclaim the memory. + drop(tables); + + // Convert the Namespace instances into NamespaceSchema instances. + let iter = namespaces + .await + .expect("namespace list task panicked")? + .into_iter() + // Ignore any namespaces that did not exist when the "columns" snapshot + // was created, or have no tables/columns (and therefore have no entry + // in "joined"). + .filter_map(move |v| { + // The catalog call explicitly asked for no soft deleted records. + assert!(v.deleted_at.is_none()); + + let mut ns = NamespaceSchema::new_empty_from(&v); + + ns.tables = joined.remove(&v.id)?; + Some((v, ns)) + }); + + Ok(iter) +} + +/// In a backoff loop, retry calling the compare-and-swap sort key catalog function if the catalog +/// returns a query error unrelated to the CAS operation. +/// +/// Returns with a value of `Ok` containing the new sort key if: +/// +/// - No concurrent updates were detected +/// - A concurrent update was detected, but the other update resulted in the same value this update +/// was attempting to set +/// +/// Returns with a value of `Err(newly_observed_value)` if a concurrent, conflicting update was +/// detected. 
It is expected that callers of this function will take the returned value into +/// account (in whatever manner is appropriate) before calling this function again. +/// +/// NOTE: it is expected that ONLY processes that ingest data (currently only the ingesters or the +/// bulk ingest API) update sort keys for existing partitions. Consider how calling this function +/// from new processes will interact with the existing calls. +pub async fn retry_cas_sort_key( + old_sort_key_ids: Option<&SortKeyIds>, + new_sort_key_ids: &SortKeyIds, + partition_id: PartitionId, + catalog: Arc, +) -> Result { + use backoff::Backoff; + use observability_deps::tracing::{info, warn}; + use std::ops::ControlFlow; + + Backoff::new(&Default::default()) + .retry_with_backoff("cas_sort_key", || { + let new_sort_key_ids = new_sort_key_ids.clone(); + let catalog = Arc::clone(&catalog); + async move { + let mut repos = catalog.repositories(); + match repos + .partitions() + .cas_sort_key(partition_id, old_sort_key_ids, &new_sort_key_ids) + .await + { + Ok(_) => ControlFlow::Break(Ok(new_sort_key_ids)), + Err(CasFailure::QueryError(e)) => ControlFlow::Continue(e), + Err(CasFailure::ValueMismatch(observed_sort_key_ids)) + if observed_sort_key_ids == new_sort_key_ids => + { + // A CAS failure occurred because of a concurrent + // sort key update, however the new catalog sort key + // exactly matches the sort key this node wants to + // commit. + // + // This is the sad-happy path, and this task can + // continue. + info!( + %partition_id, + ?old_sort_key_ids, + ?observed_sort_key_ids, + update_sort_key_ids=?new_sort_key_ids, + "detected matching concurrent sort key update" + ); + ControlFlow::Break(Ok(new_sort_key_ids)) + } + Err(CasFailure::ValueMismatch(observed_sort_key_ids)) => { + // Another ingester concurrently updated the sort + // key. + // + // This breaks a sort-key update invariant - sort + // key updates MUST be serialised. This operation must + // be retried. 
+ // + // See: + // https://github.com/influxdata/influxdb_iox/issues/6439 + // + warn!( + %partition_id, + ?old_sort_key_ids, + ?observed_sort_key_ids, + update_sort_key_ids=?new_sort_key_ids, + "detected concurrent sort key update" + ); + // Stop the retry loop with an error containing the + // newly observed sort key. + ControlFlow::Break(Err(observed_sort_key_ids)) + } + } + } + }) + .await + .expect("retry forever") +} + +/// An [`crate::interface::Error`] scoped to a single table for schema validation errors. +#[derive(Debug, Error)] +#[error("table {}, {}", .0, .1)] +pub struct TableScopedError(String, Error); + +impl TableScopedError { + /// Return the table name for this error. + pub fn table(&self) -> &str { + &self.0 + } + + /// Return a reference to the error. + pub fn err(&self) -> &Error { + &self.1 + } + + /// Return ownership of the error, discarding the table name. + pub fn into_err(self) -> Error { + self.1 + } +} + +/// Given an iterator of `(table_name, batch)` to validate, this function +/// ensures all the columns within `batch` match the existing schema for +/// `table_name` in `schema`. If the column does not already exist in `schema`, +/// it is created and an updated [`NamespaceSchema`] is returned. +/// +/// This function pushes schema additions through to the backend catalog, and +/// relies on the catalog to serialize concurrent additions of a given column, +/// ensuring only one type is ever accepted per column. +pub async fn validate_or_insert_schema<'a, T, U, R>( + tables: T, + schema: &NamespaceSchema, + repos: &mut R, +) -> Result, TableScopedError> +where + T: IntoIterator + Send + Sync, + U: Iterator + Send, + R: RepoCollection + ?Sized, +{ + let tables = tables.into_iter(); + + // The (potentially updated) NamespaceSchema to return to the caller. 
+ let mut schema = Cow::Borrowed(schema); + + for (table_name, batch) in tables { + validate_mutable_batch(batch, table_name, &mut schema, repos).await?; + } + + match schema { + Cow::Owned(v) => Ok(Some(v)), + Cow::Borrowed(_) => Ok(None), + } +} + +// &mut Cow is used to avoid a copy, so allow it +#[allow(clippy::ptr_arg)] +async fn validate_mutable_batch( + mb: &MutableBatch, + table_name: &str, + schema: &mut Cow<'_, NamespaceSchema>, + repos: &mut R, +) -> Result<(), TableScopedError> +where + R: RepoCollection + ?Sized, +{ + // Check if the table exists in the schema. + // + // Because the entry API requires &mut it is not used to avoid a premature + // clone of the Cow. + let mut table = match schema.tables.get(table_name) { + Some(t) => Cow::Borrowed(t), + None => { + // The table does not exist in the cached schema. + // + // Attempt to load an existing table from the catalog or create a new table in the + // catalog to populate the cache. + let table = + table_load_or_create(repos, schema.id, &schema.partition_template, table_name) + .await + .map_err(|e| TableScopedError(table_name.to_string(), e))?; + + assert!(schema + .to_mut() + .tables + .insert(table_name.to_string(), table) + .is_none()); + + Cow::Borrowed(schema.tables.get(table_name).unwrap()) + } + }; + + // The table is now in the schema (either by virtue of it already existing, + // or through adding it above). + // + // If the table itself needs to be updated during column validation it + // becomes a Cow::owned() copy and the modified copy should be inserted into + // the schema before returning. + validate_and_insert_columns( + mb.columns() + .map(|(name, col)| (name, col.influx_type().into())), + table_name, + &mut table, + repos, + ) + .await?; + + if let Cow::Owned(table) = table { + // The table schema was mutated and needs inserting into the namespace + // schema to make the changes visible to the caller. 
+ assert!(schema + .to_mut() + .tables + .insert(table_name.to_string(), table) + .is_some()); + } + + Ok(()) +} + +/// Given an iterator of `(column_name, column_type)` to validate, this function ensures all the +/// columns match the existing `TableSchema` in `table`. If the column does not already exist in +/// `table`, it is created and the `table` is changed to the `Cow::Owned` variant. +/// +/// This function pushes schema additions through to the backend catalog, and relies on the catalog +/// to serialize concurrent additions of a given column, ensuring only one type is ever accepted +/// per column. +// &mut Cow is used to avoid a copy, so allow it +#[allow(clippy::ptr_arg)] +pub async fn validate_and_insert_columns( + columns: impl Iterator + Send, + table_name: &str, + table: &mut Cow<'_, TableSchema>, + repos: &mut R, +) -> Result<(), TableScopedError> +where + R: RepoCollection + ?Sized, +{ + let mut column_batch: HashMap<&str, ColumnType> = HashMap::new(); + + for (name, column_type) in columns { + // Check if the column exists in the cached schema. + // + // If it does, validate it. If it does not exist, create it and insert + // it into the cached schema. + + match table.columns.get(name.as_str()) { + Some(existing) if existing.column_type == column_type => { + // No action is needed as the column matches the existing column + // schema. + } + Some(existing) => { + // The column schema and the column in the schema change are of + // different types. + return Err(TableScopedError( + table_name.to_string(), + Error::AlreadyExists { + descr: format!( + "column {} is type {} but schema update has type {}", + name, existing.column_type, column_type + ), + }, + )); + } + None => { + // The column does not exist in the cache, add it to the column + // batch to be bulk inserted later. 
+ let old = column_batch.insert(name.as_str(), column_type); + assert!( + old.is_none(), + "duplicate column name `{name}` in new column schema shouldn't be possible" + ); + } + } + } + + if !column_batch.is_empty() { + repos + .columns() + .create_or_get_many_unchecked(table.id, column_batch) + .await + .map_err(|e| TableScopedError(table_name.to_string(), e))? + .into_iter() + .for_each(|c| table.to_mut().add_column(c)); + } + + Ok(()) +} + +/// Load or create table. +pub async fn table_load_or_create( + repos: &mut R, + namespace_id: NamespaceId, + namespace_partition_template: &NamespacePartitionTemplateOverride, + table_name: &str, +) -> Result +where + R: RepoCollection + ?Sized, +{ + let table = match repos + .tables() + .get_by_namespace_and_name(namespace_id, table_name) + .await? + { + Some(table) => table, + None => { + // There is a possibility of a race condition here, if another request has also + // created this table after the `get_by_namespace_and_name` call but before + // this `create` call. In that (hopefully) rare case, do an additional fetch + // from the catalog for the record that should now exist. + let create_result = repos + .tables() + .create( + table_name, + // This table is being created implicitly by this write, so there's no + // possibility of a user-supplied partition template here, which is why there's + // a hardcoded `None`. If there is a namespace template, it must be valid because + // validity was checked during its creation, so that's why there's an `expect`. + TablePartitionTemplateOverride::try_new(None, namespace_partition_template) + .expect("no table partition template; namespace partition template has been validated"), + namespace_id, + ) + .await; + if let Err(Error::AlreadyExists { .. }) = create_result { + repos + .tables() + .get_by_namespace_and_name(namespace_id, table_name) + // Propagate any `Err` returned by the catalog + .await? 
+ // Getting `Ok(None)` should be impossible if we're in this code path because + // the `create` request just said the table exists + .expect( + "Table creation failed because the table exists, so looking up the table \ + should return `Some(table)`, but it returned `None`", + ) + } else { + create_result? + } + } + }; + + let mut table = TableSchema::new_empty_from(&table); + + // Always add a time column to all new tables. + let time_col = repos + .columns() + .create_or_get(TIME_COLUMN, table.id, ColumnType::Time) + .await?; + + table.add_column(time_col); + + Ok(table) +} + +#[cfg(test)] +mod tests { + use std::{collections::BTreeMap, sync::Arc}; + + use super::*; + use crate::{interface::SoftDeletedRows, mem::MemCatalog, util::get_schema_by_name}; + + // Generate a test that simulates multiple, sequential writes in `lp` and + // asserts the resulting schema. + // + // This test asserts the cached schema and the database entry are always in + // sync. + macro_rules! test_validate_schema { + ( + $name:ident, + lp = [$($lp:literal,)+], // An array of multi-line LP writes + want_observe_conflict = $want_observe_conflict:literal, // true if a schema validation error should be observed at some point + want_schema = {$($want_schema:tt) +} // The expected resulting schema after all writes complete. + ) => { + paste::paste! 
{ + #[allow(clippy::bool_assert_comparison)] + #[tokio::test] + async fn []() { + use crate::{interface::Catalog, test_helpers::arbitrary_namespace}; + use std::ops::DerefMut; + use pretty_assertions::assert_eq; + const NAMESPACE_NAME: &str = "bananas"; + + let metrics = Arc::new(metric::Registry::default()); + let time_provider = Arc::new(iox_time::SystemProvider::new()); + let repo = MemCatalog::new(metrics, time_provider); + let mut txn = repo.repositories(); + + let namespace = arbitrary_namespace(&mut *txn, NAMESPACE_NAME) + .await; + let schema = NamespaceSchema::new_empty_from(&namespace); + + // Apply all the lp literals as individual writes, feeding + // the result of one validation into the next to drive + // incremental construction of the schemas. + let mut observed_conflict = false; + $( + let schema = { + let lp: String = $lp.to_string(); + + let writes = mutable_batch_lp::lines_to_batches(lp.as_str(), 42) + .expect("failed to build test writes from LP"); + + let got = validate_or_insert_schema(writes.iter().map(|(k, v)| (k.as_str(), v)), &schema, txn.deref_mut()) + .await; + + match got { + Err(TableScopedError(_, Error::AlreadyExists{ .. })) => { + observed_conflict = true; + schema + }, + Err(e) => panic!("unexpected error: {}", e), + Ok(Some(new_schema)) => new_schema, + Ok(None) => schema, + } + }; + )+ + + assert_eq!($want_observe_conflict, observed_conflict, "should error mismatch"); + + // Invariant: in absence of concurrency, the schema within + // the database must always match the incrementally built + // cached schema. 
+ let db_schema = get_schema_by_name(NAMESPACE_NAME, txn.deref_mut(), SoftDeletedRows::ExcludeDeleted) + .await + .expect("database failed to query for namespace schema") + .expect("namespace exists"); + assert_eq!(schema, db_schema, "schema in DB and cached schema differ"); + + // Generate the map of tables => desired column types + let want_tables: BTreeMap, ColumnType>> = test_validate_schema!(@table, $($want_schema)+); + + // Generate a similarly structured map from the actual + // schema + let actual_tables: BTreeMap, ColumnType>> = schema + .tables + .iter() + .map(|(table, table_schema)| { + let desired_cols = table_schema + .columns + .iter() + .map(|(column, column_schema)| (Arc::clone(&column), column_schema.column_type)) + .collect::>(); + + (table.clone(), desired_cols) + }) + .collect(); + + // Assert the actual namespace contents matches the desired + // table schemas in the test args. + assert_eq!(want_tables, actual_tables, "cached schema and desired schema differ"); + } + } + }; + // Generate a map of table names => column map (below) + // + // out: BTreeMap> + (@table, $($table_name:literal: [$($columns:tt) +],)*) => {{ + let mut tables = BTreeMap::new(); + $( + let want_cols = test_validate_schema!(@column, $($columns)+); + assert!(tables.insert($table_name.to_string(), want_cols).is_none()); + )* + tables + }}; + // Generate a map of column names => ColumnType + // + // out: BTreeMap + (@column, $($col_name:literal => $col_type:expr,)+) => {{ + let mut cols = BTreeMap::new(); + $( + assert!(cols.insert(Arc::from($col_name), $col_type).is_none()); + )* + cols + }}; + } + + test_validate_schema!( + one_write_multiple_tables, + lp = [ + " + m1,t1=a,t2=b f1=2i,f2=2.0 1\n\ + m1,t1=a f1=3i 2\n\ + m2,t3=b f1=true 1\n\ + ", + ], + want_observe_conflict = false, + want_schema = { + "m1": [ + "t1" => ColumnType::Tag, + "t2" => ColumnType::Tag, + "f1" => ColumnType::I64, + "f2" => ColumnType::F64, + "time" => ColumnType::Time, + ], + "m2": [ + "f1" => 
ColumnType::Bool, + "t3" => ColumnType::Tag, + "time" => ColumnType::Time, + ], + } + ); + + // test that a new table will be created + test_validate_schema!( + two_writes_incremental_new_table, + lp = [ + " + m1,t1=a,t2=b f1=2i,f2=2.0 1\n\ + m1,t1=a f1=3i 2\n\ + m2,t3=b f1=true 1\n\ + ", + " + m1,t1=c f1=1i 2\n\ + new_measurement,t9=a f10=true 1\n\ + ", + ], + want_observe_conflict = false, + want_schema = { + "m1": [ + "t1" => ColumnType::Tag, + "t2" => ColumnType::Tag, + "f1" => ColumnType::I64, + "f2" => ColumnType::F64, + "time" => ColumnType::Time, + ], + "m2": [ + "f1" => ColumnType::Bool, + "t3" => ColumnType::Tag, + "time" => ColumnType::Time, + ], + "new_measurement": [ + "t9" => ColumnType::Tag, + "f10" => ColumnType::Bool, + "time" => ColumnType::Time, + ], + } + ); + + // test that a new column for an existing table will be created + test_validate_schema!( + two_writes_incremental_new_column, + lp = [ + " + m1,t1=a,t2=b f1=2i,f2=2.0 1\n\ + m1,t1=a f1=3i 2\n\ + m2,t3=b f1=true 1\n\ + ", + "m1,new_tag=c new_field=1i 2", + ], + want_observe_conflict = false, + want_schema = { + "m1": [ + "t1" => ColumnType::Tag, + "t2" => ColumnType::Tag, + "f1" => ColumnType::I64, + "f2" => ColumnType::F64, + "time" => ColumnType::Time, + // These are the incremental additions: + "new_tag" => ColumnType::Tag, + "new_field" => ColumnType::I64, + ], + "m2": [ + "f1" => ColumnType::Bool, + "t3" => ColumnType::Tag, + "time" => ColumnType::Time, + ], + } + ); + + test_validate_schema!( + table_always_has_time_column, + lp = [ + "m1,t1=a f1=2i", + ], + want_observe_conflict = false, + want_schema = { + "m1": [ + "t1" => ColumnType::Tag, + "f1" => ColumnType::I64, + "time" => ColumnType::Time, + ], + } + ); + + test_validate_schema!( + two_writes_conflicting_column_types, + lp = [ + "m1,t1=a f1=2i", + // Second write has conflicting type for f1. 
+ "m1,t1=a f1=2.0", + ], + want_observe_conflict = true, + want_schema = { + "m1": [ + "t1" => ColumnType::Tag, + "f1" => ColumnType::I64, + "time" => ColumnType::Time, + ], + } + ); + + test_validate_schema!( + two_writes_tag_field_transposition, + lp = [ + // x is a tag + "m1,t1=a,x=t f1=2i", + // x is a field + "m1,t1=a x=t,f1=2i", + ], + want_observe_conflict = true, + want_schema = { + "m1": [ + "t1" => ColumnType::Tag, + "x" => ColumnType::Tag, + "f1" => ColumnType::I64, + "time" => ColumnType::Time, + ], + } + ); + + #[tokio::test] + async fn validate_table_create_race_doesnt_get_all_columns() { + use crate::{interface::Catalog, test_helpers::arbitrary_namespace}; + use std::{collections::BTreeSet, ops::DerefMut}; + const NAMESPACE_NAME: &str = "bananas"; + + let repo = MemCatalog::new( + Default::default(), + Arc::new(iox_time::SystemProvider::new()), + ); + let mut txn = repo.repositories(); + let namespace = arbitrary_namespace(&mut *txn, NAMESPACE_NAME).await; + + // One cached schema has no tables. + let empty_schema = NamespaceSchema::new_empty_from(&namespace); + + // Another cached schema gets a write that creates a table with some columns. 
+ let schema_with_table = empty_schema.clone(); + let writes = mutable_batch_lp::lines_to_batches("m1,t1=a f1=2i", 42).unwrap(); + validate_or_insert_schema( + writes.iter().map(|(k, v)| (k.as_str(), v)), + &schema_with_table, + txn.deref_mut(), + ) + .await + .unwrap(); + + // then the empty schema adds the same table with some different columns + let other_writes = mutable_batch_lp::lines_to_batches("m1,t2=a f2=2i", 43).unwrap(); + let formerly_empty_schema = validate_or_insert_schema( + other_writes.iter().map(|(k, v)| (k.as_str(), v)), + &empty_schema, + txn.deref_mut(), + ) + .await + .unwrap() + .unwrap(); + + // the formerly-empty schema should NOT have all the columns; schema convergence is handled + // at a higher level by the namespace cache/gossip system + let table = formerly_empty_schema.tables.get("m1").unwrap(); + assert_eq!(table.columns.names(), BTreeSet::from(["t2", "f2", "time"])); + } +} diff --git a/iox_data_generator/Cargo.toml b/iox_data_generator/Cargo.toml new file mode 100644 index 0000000..7289896 --- /dev/null +++ b/iox_data_generator/Cargo.toml @@ -0,0 +1,48 @@ +[package] +name = "iox_data_generator" +default-run = "iox_data_generator" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[lints] +workspace = true + +[dependencies] +bytes = "1.5" +chrono = { version = "0.4", default-features = false } +clap = { version = "4", features = ["derive", "env", "cargo"] } +datafusion_util = { path = "../datafusion_util" } +futures = "0.3" +handlebars = "5.1.0" +humantime = "2.1.0" +influxdb2_client = { path = "../influxdb2_client" } +itertools = "0.12.0" +mutable_batch_lp = { path = "../mutable_batch_lp" } +mutable_batch = { path = "../mutable_batch" } +parquet_file = { path = "../parquet_file" } +rand = { version = "0.8.3", features = ["small_rng"] } +regex = "1.10" +schema = { path = "../schema" } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0.111" +snafu = "0.8" 
+tokio = { version = "1.35", features = ["macros", "parking_lot", "rt-multi-thread", "sync", "time"] } +toml = "0.8.8" +tracing = "0.1" +tracing-subscriber = "0.3" +uuid = { version = "1", default_features = false } + +[dev-dependencies] +criterion = { version = "0.5", default-features = false, features = ["rayon"]} +test_helpers = { path = "../test_helpers" } + +[[bench]] +name = "point_generation" +harness = false + +[lib] +# Allow --save-baseline to work +# https://github.com/bheisler/criterion.rs/issues/275 +bench = false diff --git a/iox_data_generator/README.md b/iox_data_generator/README.md new file mode 100644 index 0000000..3bf275f --- /dev/null +++ b/iox_data_generator/README.md @@ -0,0 +1,19 @@ +# `iox_data_generator` + +The `iox_data_generator` tool creates random data points according to a specification and loads them +into an `iox` instance to simulate real data. + +To build and run, [first install Rust](https://www.rust-lang.org/tools/install). Then from root of the `influxdb_iox` repo run: + +``` +cargo build --release +``` + +And the built binary has command line help: + +``` +./target/release/iox_data_generator --help +``` + +For examples of specifications see the [schemas folder](schemas). The [full_example](schemas/full_example.toml) is the +most comprehensive with comments and example output. 
diff --git a/iox_data_generator/benches/point_generation.rs b/iox_data_generator/benches/point_generation.rs new file mode 100644 index 0000000..e29af90 --- /dev/null +++ b/iox_data_generator/benches/point_generation.rs @@ -0,0 +1,223 @@ +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; +use iox_data_generator::{ + agent::Agent, + specification::{ + AgentAssignmentSpec, AgentSpec, DataSpec, DatabaseWriterSpec, FieldSpec, FieldValueSpec, + MeasurementSpec, + }, + tag_set::GeneratedTagSets, + write::PointsWriterBuilder, +}; +use std::{ + sync::{atomic::AtomicU64, Arc}, + time::Duration, +}; + +pub fn single_agent(c: &mut Criterion) { + let spec = DataSpec { + name: "benchmark".into(), + values: vec![], + tag_sets: vec![], + agents: vec![AgentSpec { + name: "foo".to_string(), + measurements: vec![MeasurementSpec { + name: "measurement-1".into(), + count: None, + fields: vec![FieldSpec { + name: "field-1".into(), + field_value_spec: FieldValueSpec::Bool(true), + count: None, + }], + tag_set: None, + tag_pairs: vec![], + }], + has_one: vec![], + tag_pairs: vec![], + }], + database_writers: vec![DatabaseWriterSpec { + database_ratio: Some(1.0), + database_regex: None, + agents: vec![AgentAssignmentSpec { + name: "foo".to_string(), + count: None, + sampling_interval: "1s".to_string(), + }], + }], + }; + + let mut points_writer = PointsWriterBuilder::new_no_op(true); + + let start_datetime = Some(0); + let one_hour_s = 60 * 60; + let ns_per_second = 1_000_000_000; + let end_datetime = Some(one_hour_s * ns_per_second); + + let expected_points = 3601; + + let mut group = c.benchmark_group("single_agent"); + group.throughput(Throughput::Elements(expected_points)); + + group.bench_function("single agent with basic configuration", |b| { + b.iter(|| { + let r = block_on(iox_data_generator::generate( + &spec, + vec!["foo_bar".to_string()], + &mut points_writer, + start_datetime, + end_datetime, + 0, + false, + 1, + false, + )); + let n_points = 
r.expect("Could not generate data"); + assert_eq!(n_points, expected_points as usize); + }) + }); +} + +pub fn agent_pre_generated(c: &mut Criterion) { + let spec: DataSpec = toml::from_str( + r#" +name = "storage_cardinality_example" + +# Values are automatically generated before the agents are initialized. They generate tag key/value +# pairs with the name of the value as the tag key and the evaluated template as the value. These +# pairs are Arc wrapped so they can be shared across tagsets and used in the agents as +# pre-generated data. +[[values]] +# the name must not have a . in it, which is used to access children later. Otherwise it's open. +name = "role" +# the template can use a number of helpers to get an id, a random string and the name, see below +# for examples +template = "storage" +# this number of tag pairs will be generated. If this is > 1, the id or a random character string +# should be used in the template to ensure that the tag key/value pairs are unique. +cardinality = 1 + +[[values]] +name = "url" +template = "http://127.0.0.1:6060/metrics/usage" +cardinality = 1 + +[[values]] +name = "org_id" +# Fill in the value with the cardinality counter and 15 random alphanumeric characters +template = "{{id}}_{{random 15}}" +cardinality = 1000 +has_one = ["env"] + +[[values]] +name = "env" +template = "whatever-environment-{{id}}" +cardinality = 10 + +[[values]] +name = "bucket_id" +# a bucket belongs to an org. With this, you would be able to access the org.id or org.value in the +# template +belongs_to = "org_id" +# each bucket will have a unique id, which is used here to guarantee uniqueness even across orgs. +# We also have a random 15 character alphanumeric sequence to pad out the value length. +template = "{{id}}_{{random 15}}" +# For each org, 3 buckets will be generated +cardinality = 3 + +[[values]] +name = "partition_id" +template = "{{id}}" +cardinality = 10 + +# makes a tagset so every bucket appears in every partition. 
The other tags are descriptive and +# don't increase the cardinality beyond count(bucket) * count(partition). Later this example will +# use the agent and measurement generation to take this base tagset and increase cardinality on a +# per-agent basis. +[[tag_sets]] +name = "bucket_set" +for_each = [ + "role", + "url", + "org_id", + "org_id.env", + "org_id.bucket_id", + "partition_id", +] + +[[agents]] +name = "foo" + +[[agents.measurements]] +name = "storage_usage_bucket_cardinality" +# each sampling will have all the tag sets from this collection in addition to the tags and +# tag_pairs specified +tag_set = "bucket_set" +# for each agent, this specific measurement will be decorated with these additional tags. +tag_pairs = [ + {key = "node_id", template = "{{agent.id}}"}, + {key = "hostname", template = "{{agent.id}}"}, + {key = "host", template = "storage-{{agent.id}}"}, +] + +[[agents.measurements.fields]] +name = "gauge" +i64_range = [1, 8147240] + +[[database_writers]] +agents = [{name = "foo", sampling_interval = "1s", count = 3}] +"#, + ) + .unwrap(); + + let generated_tag_sets = GeneratedTagSets::from_spec(&spec).unwrap(); + + let mut points_writer = PointsWriterBuilder::new_no_op(true); + + let start_datetime = Some(0); + let one_hour_s = 60 * 60; + let ns_per_second = 1_000_000_000; + let end_datetime = Some(one_hour_s * ns_per_second); + + let mut agents = Agent::from_spec( + &spec.agents[0], + 3, + Duration::from_millis(10), + start_datetime, + end_datetime, + 0, + false, + &generated_tag_sets, + ) + .unwrap(); + let agent = agents.first_mut().unwrap(); + let expected_points = 30000; + + let counter = Arc::new(AtomicU64::new(0)); + let request_counter = Arc::new(AtomicU64::new(0)); + let mut group = c.benchmark_group("agent_pre_generated"); + group.measurement_time(std::time::Duration::from_secs(50)); + group.throughput(Throughput::Elements(expected_points)); + + group.bench_function("single agent with basic configuration", |b| { + b.iter(|| { + 
agent.reset_current_date_time(0); + let points_writer = + Arc::new(points_writer.build_for_agent("foo", "foo", "foo").unwrap()); + let r = block_on(agent.generate_all( + points_writer, + 1, + Arc::clone(&counter), + Arc::clone(&request_counter), + )); + let n_points = r.expect("Could not generate data"); + assert_eq!(n_points.row_count, expected_points as usize); + }) + }); +} + +#[tokio::main] +async fn block_on(f: F) -> F::Output { + f.await +} + +criterion_group!(benches, single_agent, agent_pre_generated); +criterion_main!(benches); diff --git a/iox_data_generator/schemas/big_db.toml b/iox_data_generator/schemas/big_db.toml new file mode 100644 index 0000000..73ca71d --- /dev/null +++ b/iox_data_generator/schemas/big_db.toml @@ -0,0 +1,143 @@ +# this schema is for testing what it looks like with a database that has +# hundreds of thousands of measurements with different levels of throughput. +# +# The high agent sends 10k lines with 500 measurements totaling 2.48 MB per sampling +# The medium agent sends 10k lines with 1k measurements totaling 2.14 MB per sampling +# The low agent sends 10k lines with 10k measurements and 1.45 MB per sampling +# +# Based on the database_writers at the bottom, this will write 225k total measurements +# across 50 separate agents writing once every 10s. 
Aggregate throughput is about +# 35.76 MB/sec of raw line protocol +name = "big_db" + +[[values]] +name = "some_tag_here" +cardinality = 10 +template = "value-{{id}}-{{random 5}}" + +[[values]] +name = "some_other_tag" +cardinality = 2 +template = "value-{{id}}-{{random 10}}" +belongs_to = "some_tag_here" + +[[values]] +name = "some_static_tag" +cardinality = 1 +template = "whatevs-is-something-we-have" + +[[tag_sets]] +name = "20card" +for_each = [ + "some_tag_here", + "some_tag_here.some_other_tag", + "some_static_tag", +] + +[[tag_sets]] +name = "10card" +for_each = [ + "some_tag_here", + "some_static_tag", +] + +[[tag_sets]] +name = "2card" +for_each = [ + "some_other_tag", + "some_static_tag", +] + +# generates data that looks like: +# +# high_measurement_10_card_500_1,agent_id=1,foo_bar=stuff-is-here-now,measurement_id=500,some_other_tag=value-17-0wyJ8VuUO7,some_static_tag=whatevs-is-something-we-have,some_tag_here=value-9-fuFo3 intfield=63976i,floatfield=0.6004810270043124 1639597814875290000 +# high_measurement_10_card_500_1,agent_id=1,foo_bar=stuff-is-here-now,measurement_id=500,some_other_tag=value-18-I9P4V97Kfm,some_static_tag=whatevs-is-something-we-have,some_tag_here=value-9-fuFo3 intfield=24564i,floatfield=0.11957361442062764 1639597814875290000 +# high_measurement_10_card_500_1,agent_id=1,foo_bar=stuff-is-here-now,measurement_id=500,some_other_tag=value-19-HaW3lHJ2le,some_static_tag=whatevs-is-something-we-have,some_tag_here=value-10-yH0Bj intfield=18157i,floatfield=0.10429525001385809 1639597814875290000 +# high_measurement_10_card_500_1,agent_id=1,foo_bar=stuff-is-here-now,measurement_id=500,some_other_tag=value-20-XOgmzSFzm7,some_static_tag=whatevs-is-something-we-have,some_tag_here=value-10-yH0Bj intfield=51041i,floatfield=0.802468465951919 1639597814875290000 +[[agents]] +name = "high" +tag_pairs = [ + {key = "agent_id", template = "{{agent.id}}"}, + {key = "foo_bar", template = "stuff-is-here-now"} +] + +[[agents.measurements]] +name = 
"high_measurement_10_card_{{measurement.id}}_{{agent.id}}" +count = 500 +tag_set = "20card" +tag_pairs = [ + {key = "measurement_id", template = "{{measurement.id}}"} +] + +[[agents.measurements.fields]] +name = "intfield" +i64_range = [1, 100000] + +[[agents.measurements.fields]] +name = "floatfield" +f64_range = [0.0, 1.0] + +# generates data that looks like: +# +# med_measurement_10_card_1000_1,agent_id=1,foo_bar=stuff-is-here-now,measurement_id=1000,some_static_tag=whatevs-is-something-we-have,some_tag_here=value-7-UhxFA intfield=24707i,floatfield=0.762661180672112 1639597855224165000 +# med_measurement_10_card_1000_1,agent_id=1,foo_bar=stuff-is-here-now,measurement_id=1000,some_static_tag=whatevs-is-something-we-have,some_tag_here=value-8-YzAUN intfield=94490i,floatfield=0.4309492192063673 1639597855224165000 +# med_measurement_10_card_1000_1,agent_id=1,foo_bar=stuff-is-here-now,measurement_id=1000,some_static_tag=whatevs-is-something-we-have,some_tag_here=value-9-vUmMN intfield=68817i,floatfield=0.9156455784544137 1639597855224165000 +# med_measurement_10_card_1000_1,agent_id=1,foo_bar=stuff-is-here-now,measurement_id=1000,some_static_tag=whatevs-is-something-we-have,some_tag_here=value-10-gxcic intfield=84220i,floatfield=0.9267974321691199 1639597855224165000 +[[agents]] +name = "medium" +tag_pairs = [ + {key = "agent_id", template = "{{agent.id}}"}, + {key = "foo_bar", template = "stuff-is-here-now"} +] + +[[agents.measurements]] +name = "med_measurement_10_card_{{measurement.id}}_{{agent.id}}" +count = 1000 +tag_set = "10card" +tag_pairs = [ + {key = "measurement_id", template = "{{measurement.id}}"} +] + +[[agents.measurements.fields]] +name = "intfield" +i64_range = [1, 100000] + +[[agents.measurements.fields]] +name = "floatfield" +f64_range = [0.0, 1.0] + +# generates data that looks like: +# +# low_measurement_2_card_4986_1,agent_id=1,foo_bar=stuff-is-here-now,measurement_id=4986 intfield=17484i,floatfield=0.5834872217437403 1639597582877742000 +# 
low_measurement_2_card_4987_1,agent_id=1,foo_bar=stuff-is-here-now,measurement_id=4987 intfield=83563i,floatfield=0.7354522843365716 1639597582877742000 +# low_measurement_2_card_4988_1,agent_id=1,foo_bar=stuff-is-here-now,measurement_id=4988 intfield=74676i,floatfield=0.7443686050113958 1639597582877742000 +# low_measurement_2_card_4989_1,agent_id=1,foo_bar=stuff-is-here-now,measurement_id=4989 intfield=69285i,floatfield=0.05047660569705048 1639597582877742000 +# low_measurement_2_card_4990_1,agent_id=1,foo_bar=stuff-is-here-now,measurement_id=4990 intfield=36686i,floatfield=0.7546950434825994 1639597582877742000 +[[agents]] +name = "low" +tag_pairs = [ + {key = "agent_id", template = "{{agent.id}}"}, + {key = "foo_bar", template = "stuff-is-here-now"} +] + +[[agents.measurements]] +name = "low_measurement_2_card_{{measurement.id}}_{{agent.id}}" +count = 10000 +tag_pairs = [ + {key = "measurement_id", template = "{{measurement.id}}"} +] + +[[agents.measurements.fields]] +name = "intfield" +i64_range = [1, 100000] + +[[agents.measurements.fields]] +name = "floatfield" +f64_range = [0.0, 1.0] + +[[database_writers]] +agents = [ + {name = "high", sampling_interval = "10s", count = 10}, # 5,000 measurements + {name = "medium", sampling_interval = "10s", count = 20}, # 20,000 measurements + {name = "low", sampling_interval = "10s", count = 20} # 200,000 measurements +] diff --git a/iox_data_generator/schemas/cap-write.toml b/iox_data_generator/schemas/cap-write.toml new file mode 100644 index 0000000..5b77a85 --- /dev/null +++ b/iox_data_generator/schemas/cap-write.toml @@ -0,0 +1,405 @@ +# This config file aims to replicate the data produced by the capwrite tool: +# https://github.com/influxdata/idpe/tree/e493a8e9b6b773e9374a8542ddcab7d8174d320d/performance/capacity/write +name = "cap_write" + +[[database_writers]] +database_ratio = 1.0 +agents = [{name = "telegraf", count = 3, sampling_interval = "10s"}] + +[[agents]] +name = "telegraf" +tag_pairs = [ + {key = 
"host", template = "host-{{agent.id}}"} +] + +[[agents.measurements]] +name = "system" + + [[agents.measurements.fields]] + name = "n_cpus" + i64_range = [8, 8] + + [[agents.measurements.fields]] + name = "n_users" + i64_range = [2, 11] + + [[agents.measurements.fields]] + name = "uptime" + uptime = "i64" + + [[agents.measurements.fields]] + name = "uptime_format" + uptime = "telegraf" + + [[agents.measurements.fields]] + name = "load1" + f64_range = [0.0, 8.0] + + [[agents.measurements.fields]] + name = "load5" + f64_range = [0.0, 8.0] + + [[agents.measurements.fields]] + name = "load15" + f64_range = [0.0, 8.0] + + +[[agents.measurements]] +name = "mem" + + [[agents.measurements.fields]] + name = "active" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "available" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "buffered" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "cached" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "free" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "inactive" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "slab" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "total" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "used" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "available_percent" + f64_range = [0.0, 100.0] + + [[agents.measurements.fields]] + name = "used_percent" + f64_range = [0.0, 100.0] + + [[agents.measurements.fields]] + name = "wired" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "commit_limit" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "committed_as" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "dirty" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "high_free" + i64_range = [0, 10000000] + + 
[[agents.measurements.fields]] + name = "high_total" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "huge_page_size" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "huge_pages_free" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "huge_pages_total" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "low_free" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "low_total" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "mapped" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "page_tables" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "shared" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "swap_cached" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "swap_free" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "swap_total" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "vmalloc_chunk" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "vmalloc_total" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "vmalloc_used" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "write_back" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "write_back_tmp" + i64_range = [0, 10000000] + +[[agents.measurements]] +name = "disk" + + [[agents.measurements.fields]] + name = "free" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "total" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "used" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "used_percent" + f64_range = [0.0, 100.0] + + [[agents.measurements.fields]] + name = "inodes_free" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "inodes_total" + i64_range = [0, 10000000] + 
+ [[agents.measurements.fields]] + name = "inodes_used" + i64_range = [0, 10000000] + +[[agents.measurements]] +name = "swap" + + [[agents.measurements.fields]] + name = "free" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "total" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "used" + i64_range = [0, 1000000] # Note this is an order of magnitude less deliberately to match + # https://github.com/influxdata/idpe/blob/ffbceb04dd4b3aa0828d039135977a4f36f7b822/performance/capacity/write/swap.go#L17 + # not sure if that value was intentional, perhaps it is to ensure used < total? + + [[agents.measurements.fields]] + name = "used_percent" + f64_range = [0.0, 100.0] + + [[agents.measurements.fields]] + name = "in" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "out" + i64_range = [0, 10000000] + +[[agents.measurements]] +name = "cpu" +tag_pairs = [{key = "cpu", template = "cpu-total"}] + + [[agents.measurements.fields]] + name = "usage_user" + f64_range = [0.0, 100.0] + + [[agents.measurements.fields]] + name = "usage_nice" + f64_range = [0.0, 100.0] + + [[agents.measurements.fields]] + name = "usage_system" + f64_range = [0.0, 100.0] + + [[agents.measurements.fields]] + name = "usage_idle" + f64_range = [0.0, 100.0] + + [[agents.measurements.fields]] + name = "usage_irq" + f64_range = [0.0, 100.0] + + [[agents.measurements.fields]] + name = "usage_softirq" + f64_range = [0.0, 100.0] + + [[agents.measurements.fields]] + name = "usage_steal" + f64_range = [0.0, 100.0] + + [[agents.measurements.fields]] + name = "usage_guest" + f64_range = [0.0, 100.0] + + [[agents.measurements.fields]] + name = "usage_guest_nice" + f64_range = [0.0, 100.0] + +[[agents.measurements]] +name = "processes" + + [[agents.measurements.fields]] + name = "blocked" + i64_range = [0, 255] + + [[agents.measurements.fields]] + name = "running" + i64_range = [0, 255] + + [[agents.measurements.fields]] + name = "sleeping" + 
i64_range = [0, 255] + + [[agents.measurements.fields]] + name = "stopped" + i64_range = [0, 255] + + [[agents.measurements.fields]] + name = "total" + i64_range = [0, 255] + + [[agents.measurements.fields]] + name = "zombie" + i64_range = [0, 255] + + [[agents.measurements.fields]] + name = "dead" + i64_range = [0, 255] + + [[agents.measurements.fields]] + name = "wait" + i64_range = [0, 255] + + [[agents.measurements.fields]] + name = "idle" + i64_range = [0, 255] + + [[agents.measurements.fields]] + name = "paging" + i64_range = [0, 255] + + [[agents.measurements.fields]] + name = "total_threads" + i64_range = [0, 255] + + [[agents.measurements.fields]] + name = "unknown" + i64_range = [0, 255] + +[[agents.measurements]] +name = "net" + + [[agents.measurements.fields]] + name = "bytes_recv" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "bytes_sent" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "packets_sent" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "packets_recv" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "err_in" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "err_out" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "drop_in" + i64_range = [0, 10000000] + + [[agents.measurements.fields]] + name = "drop_out" + i64_range = [0, 10000000] + +[[agents.measurements]] +name = "diskio" + + [[agents.measurements.fields]] + name = "reads" + i64_range = [0, 1000000] + + [[agents.measurements.fields]] + name = "writes" + i64_range = [0, 1000000] + + [[agents.measurements.fields]] + name = "read_bytes" + i64_range = [0, 1000000] + + [[agents.measurements.fields]] + name = "write_bytes" + i64_range = [0, 1000000] + + [[agents.measurements.fields]] + name = "read_time" + i64_range = [0, 1000000] + + [[agents.measurements.fields]] + name = "write_time" + i64_range = [0, 1000000] + + [[agents.measurements.fields]] + 
name = "io_time" + i64_range = [0, 1000000] + + [[agents.measurements.fields]] + name = "weighted_io_time" + i64_range = [0, 1000000] + + [[agents.measurements.fields]] + name = "iops_in_progress" + i64_range = [0, 1000000] diff --git a/iox_data_generator/schemas/eu_central.toml b/iox_data_generator/schemas/eu_central.toml new file mode 100644 index 0000000..4d8d966 --- /dev/null +++ b/iox_data_generator/schemas/eu_central.toml @@ -0,0 +1,107 @@ +# generates load with 20k measurements getting a little bit of data, 20 measurements getting 300x the amount of data +# and 3 measurements that are very wide with 600 fields. Adjust the count or sampling interval of the three different +# agents to adjust how much load each type generates. But note that the first_agent sends far more lines per request +# which is how those measurements see so much more data. +name = "eu_central_sim" + +[[values]] +name = "some_tag" +cardinality = 10 +template = "id_{{id}}_{{random 15}}" +has_one = ["extra_static"] + +[[values]] +name = "child_tag" +cardinality = 10 +belongs_to = "some_tag" +has_one = ["rotation"] +template = "id_{{id}}_{{random 10}}" + +[[values]] +name = "rotation" +cardinality = 3 +template = "id_{{id}}_{{guid}}" + +[[values]] +name = "extra_static" +cardinality = 1 +template = "whatever-constant-value" + +[[tag_sets]] +name = "first_set" +for_each = [ + "some_tag", + "some_tag.extra_static", + "some_tag.child_tag", + "child_tag.rotation", +] + +[[tag_sets]] +name = "lower_cardinality_set" +for_each = [ + "some_tag", +] + +[[agents]] +name = "first_agent" +tag_pairs = [ + {key = "agent_id", template = "{{agent.id}}"} +] + +[[agents.measurements]] +name = "first_agent_measurement_{{measurement.id}}" +count = 20 +tag_set = "first_set" +tag_pairs = [ + {key = "measurement_id", template = "{{measurement.id}}"} +] + +[[agents.measurements.fields]] +name = "intfield" +i64_range = [1, 100000] + +[[agents.measurements.fields]] +name = "floatfield" +f64_range = [0.0, 1.0] + 
+[[agents]] +name = "second_agent" +tag_pairs = [ + {key = "agent_id", template = "second_agent_{{agent.id}}"} +] + +[[agents.measurements]] +name = "second_agent_measurement_{{measurement.id}}" +count = 20000 +tag_pairs = [ + {key = "measurement_id", template = "{{measurement.id}}"} +] + +[[agents.measurements.fields]] +name = "intfield" +i64_range = [1,1000] + +[[agents]] +name = "third_agent" +tag_pairs = [ + {key = "agent_id", template = "third_agent_{{agent.id}}"} +] + +[[agents.measurements]] +name = "third_agent_measurement_{{measurement.id}}" +count = 3 +tag_pairs = [ + {key = "measurement_id", template = "{{measurement.id}}"} +] + +[[agents.measurements.fields]] +name = "intfield_{{field.id}}" +count = 600 +i64_range = [1,1000] + +[[database_writers]] +agents = [ + {name = "first_agent", sampling_interval = "1s", count = 5}, + {name = "second_agent", sampling_interval = "1s", count = 40}, + {name = "third_agent", sampling_interval = "1s", count = 5}, +] diff --git a/iox_data_generator/schemas/full_example.toml b/iox_data_generator/schemas/full_example.toml new file mode 100644 index 0000000..34cc987 --- /dev/null +++ b/iox_data_generator/schemas/full_example.toml @@ -0,0 +1,188 @@ +# One run of the data generator output to --print will generate lines like this: + +# m1,agent_id=1,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_1,t1=t1_1,t2=t2_1_t1_1,t3=t3_1 intfield=48.31541353358504,intfield=63.16007209180341 1635968173847440000 +# m1,agent_id=1,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_2,t1=t1_1,t2=t2_1_t1_1,t3=t3_1 intfield=88.35678081075594,intfield=92.55272385943789 1635968173847440000 +# m1,agent_id=1,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_1,t1=t1_1,t2=t2_2_t1_1,t3=t3_1 intfield=71.34233494102085,intfield=19.35816384444733 1635968173847440000 +# 
m1,agent_id=1,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_2,t1=t1_1,t2=t2_2_t1_1,t3=t3_1 intfield=76.63378118605834,intfield=16.298451067775588 1635968173847440000 +# m1,agent_id=1,foo_bar=foo_8ffe4113-4680-43bd-9512-663bc84164f5_2_l8wdy_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_1,t1=t1_2,t2=t2_3_t1_2,t3=t3_2 intfield=96.71554665990536,intfield=93.44948263155631 1635968173847440000 +# m1,agent_id=1,foo_bar=foo_8ffe4113-4680-43bd-9512-663bc84164f5_2_l8wdy_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_2,t1=t1_2,t2=t2_3_t1_2,t3=t3_2 intfield=78.16527647371738,intfield=2.302033401489534 1635968173847440000 +# m1,agent_id=1,foo_bar=foo_8ffe4113-4680-43bd-9512-663bc84164f5_2_l8wdy_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_1,t1=t1_2,t2=t2_4_t1_2,t3=t3_2 intfield=90.37434758868368,intfield=7.552315135635346 1635968173847440000 +# m1,agent_id=1,foo_bar=foo_8ffe4113-4680-43bd-9512-663bc84164f5_2_l8wdy_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_2,t1=t1_2,t2=t2_4_t1_2,t3=t3_2 intfield=25.173607422073285,intfield=99.10021825896477 1635968173847440000 +# m1,agent_id=1,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_1,t1=t1_3,t2=t2_5_t1_3,t3=t3_1 intfield=31.724290085601936,intfield=71.04269945188204 1635968173847440000 +# m1,agent_id=1,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_2,t1=t1_3,t2=t2_5_t1_3,t3=t3_1 intfield=98.38837237131071,intfield=95.35495119280799 1635968173847440000 +# m1,agent_id=1,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_1,t1=t1_3,t2=t2_6_t1_3,t3=t3_1 intfield=15.860338450579835,intfield=20.932831216902017 1635968173847440000 +# m1,agent_id=1,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_2,t1=t1_3,t2=t2_6_t1_3,t3=t3_1 intfield=73.52354656855404,intfield=21.906048846128144 1635968173847440000 +# 
m2,agent_id=1,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03 i64field=1678687i,strfield="7bAK",uptime=0i,uptime_format="0 days, 00:00" 1635968173847440000 +# m2,agent_id=1,foo_bar=foo_8ffe4113-4680-43bd-9512-663bc84164f5_2_l8wdy_2021-11-03 i64field=7287348i,strfield="r2Xj",uptime=0i,uptime_format="0 days, 00:00" 1635968173847440000 +# m1,agent_id=2,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_1,t1=t1_1,t2=t2_1_t1_1,t3=t3_1 intfield=34.21564966893025,intfield=28.404777885873145 1635968173849823000 +# m1,agent_id=2,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_2,t1=t1_1,t2=t2_1_t1_1,t3=t3_1 intfield=89.53280753147736,intfield=88.35520078152399 1635968173849823000 +# m1,agent_id=2,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_1,t1=t1_1,t2=t2_2_t1_1,t3=t3_1 intfield=93.0798657117769,intfield=95.15086332651886 1635968173849823000 +# m1,agent_id=2,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_2,t1=t1_1,t2=t2_2_t1_1,t3=t3_1 intfield=16.383204148086563,intfield=69.36287104937198 1635968173849823000 +# m1,agent_id=2,foo_bar=foo_8ffe4113-4680-43bd-9512-663bc84164f5_2_l8wdy_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_1,t1=t1_2,t2=t2_3_t1_2,t3=t3_2 intfield=86.07310267461553,intfield=84.1837111118747 1635968173849823000 +# m1,agent_id=2,foo_bar=foo_8ffe4113-4680-43bd-9512-663bc84164f5_2_l8wdy_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_2,t1=t1_2,t2=t2_3_t1_2,t3=t3_2 intfield=66.97292091697567,intfield=13.792714677819795 1635968173849823000 +# m1,agent_id=2,foo_bar=foo_8ffe4113-4680-43bd-9512-663bc84164f5_2_l8wdy_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_1,t1=t1_2,t2=t2_4_t1_2,t3=t3_2 intfield=41.66956499741617,intfield=60.54778655915278 1635968173849823000 +# 
m1,agent_id=2,foo_bar=foo_8ffe4113-4680-43bd-9512-663bc84164f5_2_l8wdy_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_2,t1=t1_2,t2=t2_4_t1_2,t3=t3_2 intfield=50.85432735762039,intfield=51.71473345880968 1635968173849823000 +# m1,agent_id=2,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_1,t1=t1_3,t2=t2_5_t1_3,t3=t3_1 intfield=35.488387176278735,intfield=40.69930728826883 1635968173849823000 +# m1,agent_id=2,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_2,t1=t1_3,t2=t2_5_t1_3,t3=t3_1 intfield=52.224104265522485,intfield=17.630042482636732 1635968173849823000 +# m1,agent_id=2,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_1,t1=t1_3,t2=t2_6_t1_3,t3=t3_1 intfield=37.061044012796174,intfield=71.24055048796617 1635968173849823000 +# m1,agent_id=2,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03,m1tag=1-1,m1tag2=1-2,other=o_2,t1=t1_3,t2=t2_6_t1_3,t3=t3_1 intfield=31.513973186770073,intfield=61.978547758411295 1635968173849823000 +# m2,agent_id=2,foo_bar=foo_8386ce4f-958d-42d5-9826-5dffc7e35ff9_1_DsZup_2021-11-03 i64field=3258846i,strfield="j9PD",uptime=0i,uptime_format="0 days, 00:00" 1635968173849823000 +# m2,agent_id=2,foo_bar=foo_8ffe4113-4680-43bd-9512-663bc84164f5_2_l8wdy_2021-11-03 i64field=4426192i,strfield="RBhn",uptime=0i,uptime_format="0 days, 00:00" 1635968173849823000 +# new_agent_measurement-1,agent-name=another_example f1=f 1635968398968189000 +# new_agent_measurement-2,agent-name=another_example f1=f 1635968398968189000 + +name = "full_example" + +# Values are automatically generated before the agents are initialized. They generate tag key/value pairs +# with the name of the value as the tag key and the evaluated template as the value. These pairs +# can be shared across tagsets and used in the agents as pre-generated data. 
+[[values]]
+# The name appears as the tag key in generated tag pairs and is used later when specifying tag_sets.
+# It must not have a . in it, which is used to access children later.
+name = "foo_bar"
+# This number of tag pairs will be generated. If this is > 1, the id or a random character string should be
+# used in the template to ensure that the tag key/value pairs are unique.
+cardinality = 2
+# The template will be evaluated for each tag pair (so N times where n == cardinality)
+# the template can use a number of helpers, which are shown in this example.
+# guid - generate a guid
+# id - the id of the tag pair. ids start at 1
+# random - generate random character string of the passed in length
+# format-time - strftime formatted string for now
+template = "foo_{{guid}}_{{id}}_{{random 5}}_{{format-time \"%Y-%m-%d\"}}"
+
+[[values]]
+name = "t1"
+template = "t1_{{id}}"
+cardinality = 3
+# each t1 generated will reference one of t3 and one of foo_bar. As each t1 is generated
+# it will loop through the t3 and foo_bar collections. So the 3rd t1 that is generated will
+# reference the first t3 and foo_bar
+has_one = ["t3", "foo_bar"]
+
+[[values]]
+name = "t2"
+# note that in this template we can access the parent's id because of the belongs_to
+template = "t2_{{id}}_t1_{{t1.id}}"
+cardinality = 2
+belongs_to = "t1"
+
+[[values]]
+name = "t3"
+template = "t3_{{id}}"
+cardinality = 2
+
+[[values]]
+name = "other"
+template = "o_{{id}}"
+cardinality = 2
+
+# tag_sets can be used later in the measurement specification. Each measurement can use one
+# tag set (or none). For each sampling that is generated, each measurement will have lines
+# generated equal to the cardinality of the tagset.
+[[tag_sets]]
+name = "example"
+# for_each specifies how to iterate through the values to generate tagsets. If you want to
+# use values that belong_to others or are a has_one, specify their parent first. 
For values +# without relationships, you'll get a combined cardinality of each multiplied by the other. +# In this example we get cardinality of card(t1) * card(foo_bar) * card(other). The has_one +# members of t1 don't increase cardinality. +for_each = [ + "t1", + "t1.t3", + "t1.foo_bar", + "t1.t2", + "other", +] + +[[tag_sets]] +name = "foos" +# note here that we have a tag set of foo_bar tag pairs. Values can be used outside the +# context of where they may be referenced in belong_to or has_one +for_each = [ + "foo_bar" +] + +# Agent specs can be referenced later on by bucket writers, which specify how frequently +# data should be written and by how many different agents. +[[agents]] +name = "first_agent" +# if specifying tag_pairs at the agent level, every line that the agent generates will have these +# tag pairs added to it. Note that the template has the same helpers as those in value (except for id). +# In addition, it has an accessor for the agent id. +tag_pairs = [ + {key = "agent_id", template = "{{agent.id}}"} +] + +[[agents.measurements]] +name = "m1" +# each sampling will have all the tag sets from this collection in addition to the tags and tag_pairs specified +tag_set = "example" +# for each agent, this specific measurement will be decorated with these additional tags. All the previous +# template helpers are available including now `measurement.id` and `tag.id`. +# This example also shows how to automatically generate many tags using `count` and how to specify +# that the tag value template should be re-evaluated after N number of lines. This N is counted across +# samplings. 
+tag_pairs = [ + {key = "m1tag", template = "{{measurement.id}}-{{id}}", count = 2, regenerate_after_lines = 5} +] + +# field values are generated on every line as they're written out +[[agents.measurements.fields]] +name = "intfield" +# Count is optional, we can use it to automatically create many fields +count = 2 +f64_range = [0.0, 100.00] + +[[agents.measurements]] +name = "m2" +tag_set = "foos" + +[[agents.measurements.fields]] +name = "i64field" +i64_range = [1, 8147240] + +[[agents.measurements.fields]] +name = "strfield" +template = "{{random 4}}" + +# this generates an int value representing how long the agent has been running +[[agents.measurements.fields]] +name = "uptime" +uptime = "i64" + +# generates uptime as a string value +[[agents.measurements.fields]] +name = "uptime_format" +uptime = "telegraf" + +[[agents]] +name = "another_example" +tag_pairs = [{key = "agent_name", template = "agent.name"}] + +[[agents.measurements]] +name = "new_agent_measurement-{{measurement.id}}" +# you can automatically generate many measurements with the same schema +count = 2 + +[[agents.measurements.fields]] +name = "f1" +bool = true + +# database_writers specify how to split up the list of supplied buckets to write to. If +# only a single one is specified via the CLI flags, then you'd want only a single bucket_writer +# with a percent of 1.0. +# +# These make it possible to split up a large list of buckets to write to and send different +# amounts of write load as well as different schemas through specifying different agents. +[[database_writers]] +# the first 20% of the databases specified in the --bucket_list file will have these agents writing to them +database_ratio = 0.2 +# for each of those databases, have 3 of the first_agent writing every 10s, and 1 of the another_example writing every minute. 
+agents = [{name = "first_agent", count = 3, sampling_interval = "10s"}, {name = "another_example", sampling_interval = "1m"}] + +[[database_writers]] +# the remaining 80% of the databases specified will write using these agents +database_ratio = 0.8 +# we'll only have a single agent of another_example for each database +agents = [{name = "another_example", sampling_interval = "1s"}] diff --git a/iox_data_generator/schemas/many_dbs.toml b/iox_data_generator/schemas/many_dbs.toml new file mode 100644 index 0000000..8f160e7 --- /dev/null +++ b/iox_data_generator/schemas/many_dbs.toml @@ -0,0 +1,80 @@ +# This schema is meant to test out many databases writing data in like a bunch of free tier users. +# Start with a database_list of 10k to make things interesting. This will send on average of +# 208 requests/sec and 6.25 MB/sec across the 10k databases. The top 10 will have 60 requests/min +# and 1.8MB/min. +name = "many_dbs" + +[[values]] +name = "some_tag_10" +cardinality = 2 +template = "id_{{id}}_{{random 15}}" +has_one = ["extra_static"] + +[[values]] +name = "child_tag" +cardinality = 3 +belongs_to = "some_tag_10" +has_one = ["rotation"] +template = "id_{{id}}_{{random 10}}" + +[[values]] +name = "rotation" +cardinality = 4 +template = "id_{{id}}_{{guid}}" + +[[values]] +name = "extra_static" +cardinality = 1 +template = "whatever-constant-value" + +[[tag_sets]] +name = "first_set" +for_each = [ + "some_tag_10", + "some_tag_10.extra_static", + "some_tag_10.child_tag", + "child_tag.rotation", +] + +# each sampling from this agent generates 32,465 bytes of LP, first few lines look like: +# main_measurement_1,agent_id=1,child_tag=id_1_T6iJnnBTE3,extra_static=whatever-constant-value,measurement_tag=1,rotation=id_1_de4ddb8c-31a6-440f-a273-7132bdd43bd7,some_tag_10=id_1_rWtIkI26LTlfu0J intfield=71334i,floatfield=0.7934452557768101 1639151629935287000 +# 
main_measurement_1,agent_id=1,child_tag=id_2_VsiUF2xVuz,extra_static=whatever-constant-value,measurement_tag=1,rotation=id_2_890145b3-8157-4d6f-ac02-1fe37584190f,some_tag_10=id_1_rWtIkI26LTlfu0J intfield=64582i,floatfield=0.0957134480635704 1639151629935287000 +# main_measurement_1,agent_id=1,child_tag=id_3_XNL51f1NdT,extra_static=whatever-constant-value,measurement_tag=1,rotation=id_3_8bcf7547-06e9-4033-9ffb-e00ac4e6c5a9,some_tag_10=id_1_rWtIkI26LTlfu0J intfield=26179i,floatfield=0.09993902612184669 1639151629935287000 +# main_measurement_1,agent_id=1,child_tag=id_4_mqCyprcTDQ,extra_static=whatever-constant-value,measurement_tag=1,rotation=id_4_f465d43e-f1ab-4250-99ac-67af7c1d4c72,some_tag_10=id_2_X4eWjH9ImjTeta2 intfield=16511i,floatfield=0.033060266070114475 1639151629935287000 +[[agents]] +name = "first_agent" +tag_pairs = [ + {key = "agent_id", template = "{{agent.id}}"} +] + +[[agents.measurements]] +name = "main_measurement_{{measurement.id}}" +count = 20 +tag_set = "first_set" +tag_pairs = [ + {key = "measurement_tag", template = "{{measurement.id}}"} +] + +[[agents.measurements.fields]] +name = "intfield" +i64_range = [1, 100000] + +[[agents.measurements.fields]] +name = "floatfield" +f64_range = [0.0, 1.0] + +[[database_writers]] +database_ratio = 0.001 +agents = [{name = "first_agent", sampling_interval = "1s"}] + +[[database_writers]] +database_ratio = 0.01 +agents = [{name = "first_agent", sampling_interval = "10s"}] + +[[database_writers]] +database_ratio = 0.1 +agents = [{name = "first_agent", sampling_interval = "30s"}] + +[[database_writers]] +database_ratio = 1.0 +agents = [{name = "first_agent", sampling_interval = "60s"}] \ No newline at end of file diff --git a/iox_data_generator/schemas/many_measurements.toml b/iox_data_generator/schemas/many_measurements.toml new file mode 100644 index 0000000..66bd27d --- /dev/null +++ b/iox_data_generator/schemas/many_measurements.toml @@ -0,0 +1,61 @@ +# This schema tests what load looks like with many 
measurements (2,000). If pointed at a single database +# with the configured 20 agents at 10s sampling, it will send an average of 2 requests/second (representing +# 16k rows) with 4.1MB/second of LP being written. Each agent writes 8k lines per request. +name = "many_measurements" + +[[values]] +name = "some_tag" +cardinality = 2 +template = "id_{{id}}_{{random 15}}" +has_one = ["extra_static"] + +[[values]] +name = "child_tag" +cardinality = 2 +belongs_to = "some_tag" +has_one = ["rotation"] +template = "id_{{id}}_{{random 10}}" + +[[values]] +name = "rotation" +cardinality = 3 +template = "id_{{id}}_{{guid}}" + +[[values]] +name = "extra_static" +cardinality = 1 +template = "whatever-constant-value" + +[[tag_sets]] +name = "first_set" +for_each = [ + "some_tag", + "some_tag.extra_static", + "some_tag.child_tag", + "child_tag.rotation", +] + +[[agents]] +name = "first_agent" +tag_pairs = [ + {key = "agent_id", template = "{{agent.id}}"} +] + +[[agents.measurements]] +name = "main_measurement_{{measurement.id}}" +count = 2000 +tag_set = "first_set" +tag_pairs = [ + {key = "measurement_id", template = "{{measurement.id}}"} +] + +[[agents.measurements.fields]] +name = "intfield" +i64_range = [1, 100000] + +[[agents.measurements.fields]] +name = "floatfield" +f64_range = [0.0, 1.0] + +[[database_writers]] +agents = [{name = "first_agent", sampling_interval = "10s", count = 20}] diff --git a/iox_data_generator/schemas/storage_cardinality_example.toml b/iox_data_generator/schemas/storage_cardinality_example.toml new file mode 100644 index 0000000..15b9707 --- /dev/null +++ b/iox_data_generator/schemas/storage_cardinality_example.toml @@ -0,0 +1,81 @@ +name = "storage_cardinality_example" + +# Values are automatically generated before the agents are initialized. They generate tag key/value pairs +# with the name of the value as the tag key and the evaluated template as the value. 
These pairs +# are Arc wrapped so they can be shared across tagsets and used in the agents as pre-generated data. +[[values]] +# the name must not have a . in it, which is used to access children later. Otherwise it's open. +name = "role" +# the template can use a number of helpers to get an id, a random string and the name, see below for examples +template = "storage" +# this number of tag pairs will be generated. If this is > 1, the id or a random character string should be +# used in the template to ensure that the tag key/value pairs are unique. +cardinality = 1 + +[[values]] +name = "url" +template = "http://127.0.0.1:6060/metrics/usage" +cardinality = 1 + +[[values]] +name = "org_id" +# Fill in the value with the cardinality counter and 15 random alphanumeric characters +template = "{{id}}_{{random 15}}" +cardinality = 100 +has_one = ["env"] + +[[values]] +name = "env" +template = "whatever-environment-{{id}}" +cardinality = 2 + +[[values]] +name = "bucket_id" +# a bucket belongs to an org. With this, you would be able to access the org.id or org.value in the template +belongs_to = "org_id" +# each bucket will have a unique id, which is used here to guarantee uniqueness even across orgs. We also +# have a random 15 character alphanumeric sequence to pad out the value length. +template = "{{id}}_{{random 15}}" +# For each org, 3 buckets will be generated +cardinality = 3 + +[[values]] +name = "partition_id" +template = "{{id}}" +cardinality = 10 + +# makes a tagset so every bucket appears in every partition. The other tags are descriptive and don't +# increase the cardinality beyond count(bucket) * count(partition). Later this example will use the +# agent and measurement generation to take this base tagset and increase cardinality on a per-agent basis. 
+[[tag_sets]] +name = "bucket_set" +for_each = [ + "role", + "url", + "org_id", + "org_id.env", + "org_id.bucket_id", + "partition_id", +] + +[[database_writers]] +database_ratio = 1.0 +agents = [{name = "sender", sampling_interval = "10s"}] + +[[agents]] +name = "sender" + +[[agents.measurements]] +name = "storage_usage_bucket_cardinality" +# each sampling will have all the tag sets from this collection in addition to the tags and tag_pairs specified +tag_set = "bucket_set" +# for each agent, this specific measurement will be decorated with these additional tags. +tag_pairs = [ + {key = "node_id", template = "{{agent.id}}"}, + {key = "hostname", template = "{{agent.id}}"}, + {key = "host", template = "storage-{{agent.id}}"}, +] + +[[agents.measurements.fields]] +name = "gauge" +i64_range = [1, 8147240] diff --git a/iox_data_generator/schemas/tracing-spec.toml b/iox_data_generator/schemas/tracing-spec.toml new file mode 100644 index 0000000..1df620d --- /dev/null +++ b/iox_data_generator/schemas/tracing-spec.toml @@ -0,0 +1,35 @@ +name = "tracing_schema" + +[[values]] +name = "host" +template = "server-{{id}}" +cardinality = 3000 +has_one = ["service"] + +[[values]] +name = "service" +template = "service-{{id}}" +cardinality = 10 + +[[tag_sets]] +name = "host_services" +for_each = ["host", "host.service"] + +[[agents]] +name = "tracing_agent" + +[[agents.measurements]] +name = "traces" +tag_set = "host_services" +tag_pairs = [ + {key = "trace_id", template = "{{guid}}", regenerate_after_lines = 10}, + {key = "span_id", template = "{{guid}}", regenerate_after_lines = 1}, +] + +[[agents.measurements.fields]] +name = "timing" +f64_range = [0.0, 500.0] + +[[database_writers]] +database_ratio = 1.0 +agents = [{name = "tracing_agent", sampling_interval = "1s"}] diff --git a/iox_data_generator/src/agent.rs b/iox_data_generator/src/agent.rs new file mode 100644 index 0000000..aeed284 --- /dev/null +++ b/iox_data_generator/src/agent.rs @@ -0,0 +1,692 @@ +//! 
Agents responsible for generating points + +use crate::{ + measurement::{MeasurementGenerator, MeasurementLineIterator}, + now_ns, specification, + tag_pair::TagPair, + write::PointsWriter, +}; + +use crate::tag_set::GeneratedTagSets; +use serde_json::json; +use snafu::{ResultExt, Snafu}; +use std::sync::{ + atomic::{AtomicU64, Ordering}, + Arc, +}; +use std::time::{Duration, Instant}; +use tracing::debug; + +/// Agent-specific Results +pub type Result = std::result::Result; + +/// Errors that may happen while creating points +#[derive(Snafu, Debug)] +#[allow(missing_docs)] +pub enum Error { + #[snafu(display("{}", source))] + CouldNotGeneratePoint { + /// Underlying `measurement` module error that caused this problem + source: crate::measurement::Error, + }, + + #[snafu(display("Could not create measurement generators, caused by:\n{}", source))] + CouldNotCreateMeasurementGenerators { + /// Underlying `measurement` module error that caused this problem + source: crate::measurement::Error, + }, + + #[snafu(display("Could not write points, caused by:\n{}", source))] + CouldNotWritePoints { + /// Underlying `write` module error that caused this problem + source: crate::write::Error, + }, + + #[snafu(display("Error creating agent tag pairs: {}", source))] + CouldNotCreateAgentTagPairs { source: crate::tag_pair::Error }, +} + +/// Each `AgentSpec` informs the instantiation of an `Agent`, which coordinates +/// the generation of the measurements in their specification. +#[derive(Debug)] +pub struct Agent { + /// identifier for the agent. This can be used in generated tags and fields + pub id: usize, + /// name for the agent. 
This can be used in generated tags and fields + pub name: String, + measurement_generators: Vec, + sampling_interval: Option, + /// nanoseconds since the epoch, used as the timestamp for the next + /// generated point + current_datetime: i64, + /// nanoseconds since the epoch, when current_datetime exceeds this, stop + /// generating points + end_datetime: i64, + /// whether to continue generating points after reaching the current time + continue_on: bool, + /// whether this agent is done generating points or not + finished: bool, + /// Optional interval at which to re-run the agent if generating data in + /// "continue" mode + interval: Option, +} + +/// Basic stats for agents generating requests +#[derive(Debug, Default, Copy, Clone)] +pub struct AgentGenerateStats { + /// number of rows the agent has written + pub row_count: usize, + /// number of requests the agent has made + pub request_count: usize, + /// number of errors + pub error_count: usize, +} + +impl AgentGenerateStats { + /// Display output for agent writing stats + pub fn display_stats(&self, elapsed_time: Duration) -> String { + if elapsed_time.as_secs() == 0 { + format!( + "made {} requests with {} rows in {:?} with {} errors for a {:.2} error rate", + self.request_count, + self.row_count, + elapsed_time, + self.error_count, + self.error_rate() + ) + } else { + let req_secs = elapsed_time.as_secs(); + let rows_per_sec = self.row_count as u64 / req_secs; + let reqs_per_sec = self.request_count as u64 / req_secs; + format!("made {} requests at {}/sec with {} rows at {}/sec in {:?} with {} errors for a {:.2} error rate", + self.request_count, reqs_per_sec, self.row_count, rows_per_sec, elapsed_time, self.error_count, self.error_rate()) + } + } + + fn error_rate(&self) -> f64 { + if self.error_count == 0 { + return 0.0; + } + self.error_count as f64 / self.request_count as f64 * 100.0 + } +} + +impl Agent { + /// Create agents that will generate data points according to these + /// specs. 
+ #[allow(clippy::too_many_arguments)] + pub fn from_spec( + agent_spec: &specification::AgentSpec, + count: usize, + sampling_interval: Duration, + start_datetime: Option, // in nanoseconds since the epoch, defaults to now + end_datetime: Option, // also in nanoseconds since the epoch, defaults to now + execution_start_time: i64, + continue_on: bool, // If true, run in "continue" mode after historical data is generated + generated_tag_sets: &GeneratedTagSets, + ) -> Result> { + let agents: Vec<_> = (1..count + 1) + .map(|agent_id| { + let data = json!({"agent": {"id": agent_id, "name": agent_spec.name}}); + + let agent_tag_pairs = TagPair::pairs_from_specs(&agent_spec.tag_pairs, data) + .context(CouldNotCreateAgentTagPairsSnafu)?; + + let measurement_generators = agent_spec + .measurements + .iter() + .map(|spec| { + MeasurementGenerator::from_spec( + agent_id, + spec, + execution_start_time, + generated_tag_sets, + &agent_tag_pairs, + ) + .context(CouldNotCreateMeasurementGeneratorsSnafu) + }) + .collect::>>()?; + let measurement_generators = measurement_generators.into_iter().flatten().collect(); + + let current_datetime = start_datetime.unwrap_or_else(now_ns); + let end_datetime = end_datetime.unwrap_or_else(now_ns); + + Ok(Self { + id: agent_id, + name: agent_spec.name.to_string(), + measurement_generators, + sampling_interval: Some(sampling_interval), + current_datetime, + end_datetime, + continue_on, + finished: false, + interval: None, + }) + }) + .collect::>>()?; + + Ok(agents) + } + + /// Generate and write points in batches until `generate` doesn't return any + /// points. Points will be written to the writer in batches where `generate` is + /// called `batch_size` times before writing. Meant to be called in a `tokio::task`. 
+ pub async fn generate_all( + &mut self, + points_writer: Arc, + batch_size: usize, + counter: Arc, + request_counter: Arc, + ) -> Result { + let mut points_this_batch = 1; + let start = Instant::now(); + let mut stats = AgentGenerateStats::default(); + + while points_this_batch != 0 { + let batch_start = Instant::now(); + points_this_batch = 0; + + let mut streams = Vec::with_capacity(batch_size); + for _ in 0..batch_size { + if self.finished { + break; + } else { + let mut s = self.generate().await?; + streams.append(&mut s); + } + } + + for s in &streams { + points_this_batch += s.line_count(); + } + + if points_this_batch == 0 && self.finished { + break; + } + + stats.request_count += 1; + match points_writer + .write_points(streams.into_iter().flatten()) + .await + .context(CouldNotWritePointsSnafu) + { + Ok(_) => { + stats.row_count += points_this_batch; + + if stats.request_count % 10 == 0 { + println!( + "Agent {} wrote {} in {:?}", + self.id, + points_this_batch, + batch_start.elapsed() + ); + } + + // output something on the aggregate stats every 100 requests across all agents + let total_rows = counter.fetch_add(points_this_batch as u64, Ordering::SeqCst); + let total_requests = request_counter.fetch_add(1, Ordering::SeqCst); + + if total_requests % 100 == 0 { + let secs = start.elapsed().as_secs(); + if secs != 0 { + println!( + "{} rows written in {} requests for {} rows/sec and {} reqs/sec", + total_rows, + total_requests, + total_rows / secs, + total_requests / secs, + ) + } + } + } + Err(e) => { + eprintln!("Error writing points: {e}"); + stats.error_count += 1; + } + } + } + + Ok(stats) + } + + /// Generate data points from the configuration in this agent. + pub async fn generate(&mut self) -> Result> { + debug!( + "[agent {}] finished? 
{} current: {}, end: {}", + self.id, self.finished, self.current_datetime, self.end_datetime + ); + + if !self.finished { + let mut measurement_streams = Vec::with_capacity(self.measurement_generators.len()); + + // Save the current_datetime to use in the set of points that we're generating + // because we might increment current_datetime to see if we're done + // or not. + let point_timestamp = self.current_datetime; + + if let Some(i) = &mut self.interval { + i.tick().await; + self.current_datetime = now_ns(); + } else if let Some(sampling_interval) = self.sampling_interval { + self.current_datetime += sampling_interval.as_nanos() as i64; + + if self.current_datetime > self.end_datetime { + if self.continue_on { + let mut i = tokio::time::interval(sampling_interval); + i.tick().await; // first tick completes immediately + self.current_datetime = now_ns(); + self.interval = Some(i); + } else { + self.finished = true; + } + } + } else { + self.finished = true; + } + + for mgs in &mut self.measurement_generators { + measurement_streams.push( + mgs.generate(point_timestamp) + .context(CouldNotGeneratePointSnafu)?, + ); + } + + Ok(measurement_streams) + } else { + Ok(Vec::new()) + } + } + + /// Sets the current date and time for the agent and resets its finished state to false. Enables + /// calling generate again during testing and benchmarking. + pub fn reset_current_date_time(&mut self, current_datetime: i64) { + self.finished = false; + self.current_datetime = current_datetime; + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::measurement::LineToGenerate; + use crate::{now_ns, specification::*}; + use influxdb2_client::models::WriteDataPoint; + + type Error = Box; + type Result = std::result::Result; + + impl Agent { + /// Instantiate an agent only with the parameters we're interested in + /// testing, keeping everything else constant across different + /// tests. 
+ fn test_instance( + sampling_interval: Option, + continue_on: bool, + current_datetime: i64, + end_datetime: i64, + ) -> Self { + let measurement_spec = MeasurementSpec { + name: "measurement-{{agent.id}}-{{measurement.id}}".into(), + count: Some(2), + fields: vec![FieldSpec { + name: "field-{{agent.id}}-{{measurement.id}}-{{field.id}}".into(), + field_value_spec: FieldValueSpec::I64 { + range: 0..60, + increment: false, + reset_after: None, + }, + count: Some(2), + }], + tag_pairs: vec![], + tag_set: None, + }; + + let generated_tag_sets = GeneratedTagSets::default(); + + let measurement_generators = MeasurementGenerator::from_spec( + 1, + &measurement_spec, + current_datetime, + &generated_tag_sets, + &[], + ) + .unwrap(); + + Self { + id: 0, + name: "foo".to_string(), + finished: false, + interval: None, + + sampling_interval, + current_datetime, + end_datetime, + continue_on, + measurement_generators, + } + } + } + + fn timestamps(points: &[LineToGenerate]) -> Result> { + points + .iter() + .map(|point| { + let mut v = Vec::new(); + point.write_data_point_to(&mut v)?; + let line = String::from_utf8(v)?; + + Ok(line.split(' ').last().unwrap().trim().parse()?) + }) + .collect() + } + + #[rustfmt::skip] + // # Summary: No Sampling Interval + // + // If there isn't a sampling interval, we don't know how often to run, so we can neither + // generate historical data nor can we continue into the future. The only thing we'll do is + // generate once then stop. 
+ // + // | sampling_interval | continue | cmp(current_time, end_time) | expected outcome | + // |-------------------+----------+-----------------------------+------------------| + // | None | false | Less | gen 1x, stop | + // | None | false | Equal | gen 1x, stop | + // | None | false | Greater | gen 1x, stop | + // | None | true | Less | gen 1x, stop | + // | None | true | Equal | gen 1x, stop | + // | None | true | Greater | gen 1x, stop | + + mod without_sampling_interval { + use super::*; + + mod without_continue { + use super::*; + + #[tokio::test] + async fn current_time_less_than_end_time() -> Result<()> { + let mut agent = Agent::test_instance(None, false, 0, 10); + + let points = agent.generate().await?.into_iter().flatten(); + assert_eq!(points.count(), 2); + + let points = agent.generate().await?.into_iter().flatten(); + let points: Vec<_> = points.collect(); + assert!(points.is_empty(), "expected no points, got {points:?}"); + + Ok(()) + } + + #[tokio::test] + async fn current_time_equal_end_time() -> Result<()> { + let mut agent = Agent::test_instance(None, false, 10, 10); + + let points = agent.generate().await?.into_iter().flatten(); + assert_eq!(points.count(), 2); + + let points = agent.generate().await?.into_iter().flatten(); + let points: Vec<_> = points.collect(); + assert!(points.is_empty(), "expected no points, got {points:?}"); + + Ok(()) + } + + #[tokio::test] + async fn current_time_greater_than_end_time() -> Result<()> { + let mut agent = Agent::test_instance(None, false, 11, 10); + + let points = agent.generate().await?.into_iter().flatten(); + assert_eq!(points.count(), 2); + + let points = agent.generate().await?.into_iter().flatten(); + let points: Vec<_> = points.collect(); + assert!(points.is_empty(), "expected no points, got {points:?}"); + + Ok(()) + } + } + + mod with_continue { + use super::*; + + #[tokio::test] + async fn current_time_less_than_end_time() -> Result<()> { + let mut agent = Agent::test_instance(None, true, 0, 
10); + + let points = agent.generate().await?.into_iter().flatten(); + assert_eq!(points.count(), 2); + + let points = agent.generate().await?.into_iter().flatten(); + let points: Vec<_> = points.collect(); + assert!(points.is_empty(), "expected no points, got {points:?}"); + + Ok(()) + } + + #[tokio::test] + async fn current_time_equal_end_time() -> Result<()> { + let mut agent = Agent::test_instance(None, true, 10, 10); + + let points = agent.generate().await?.into_iter().flatten(); + assert_eq!(points.count(), 2); + + let points = agent.generate().await?.into_iter().flatten(); + let points: Vec<_> = points.collect(); + assert!(points.is_empty(), "expected no points, got {points:?}"); + + Ok(()) + } + + #[tokio::test] + async fn current_time_greater_than_end_time() -> Result<()> { + let mut agent = Agent::test_instance(None, true, 11, 10); + + let points = agent.generate().await?.into_iter().flatten(); + assert_eq!(points.count(), 2); + + let points = agent.generate().await?.into_iter().flatten(); + let points: Vec<_> = points.collect(); + assert!(points.is_empty(), "expected no points, got {points:?}"); + + Ok(()) + } + } + } + + mod with_sampling_interval { + use super::*; + + // The tests take about 5 ms to run on my computer, so set the sampling interval + // to 10 ms to be able to test that the delay is happening when + // `continue` is true without making the tests too artificially slow. + const TEST_SAMPLING_INTERVAL: Duration = Duration::from_millis(10); + + #[rustfmt::skip] + // # Summary: Not continuing + // + // If there is a sampling interval but we're not continuing, we should generate points at + // least once but if the current time is greater than the ending time (which might be set + // to `now`), we've generated everything we need to and should stop. 
+ // + // | sampling_interval | continue | cmp(current_time, end_time) | expected outcome | + // |-------------------+----------+-----------------------------+------------------| + // | Some(_) | false | Less | gen & increment | + // | Some(_) | false | Equal | gen 1x, stop | + // | Some(_) | false | Greater | gen 1x, stop | + + mod without_continue { + use super::*; + + #[tokio::test] + async fn current_time_less_than_end_time() -> Result<()> { + let current = 0; + let end = TEST_SAMPLING_INTERVAL.as_nanos() as i64; + + let mut agent = + Agent::test_instance(Some(TEST_SAMPLING_INTERVAL), false, current, end); + + let points = agent.generate().await?.into_iter().flatten(); + assert_eq!(points.count(), 2); + + let points = agent.generate().await?.into_iter().flatten(); + assert_eq!(points.count(), 2); + + let points = agent.generate().await?.into_iter().flatten(); + let points: Vec<_> = points.collect(); + assert!(points.is_empty(), "expected no points, got {points:?}"); + + Ok(()) + } + + #[tokio::test] + async fn current_time_equal_end_time() -> Result<()> { + let current = TEST_SAMPLING_INTERVAL.as_nanos() as i64; + let end = current; + + let mut agent = + Agent::test_instance(Some(TEST_SAMPLING_INTERVAL), false, current, end); + + let points = agent.generate().await?.into_iter().flatten(); + assert_eq!(points.count(), 2); + + let points = agent.generate().await?.into_iter().flatten(); + let points: Vec<_> = points.collect(); + assert!(points.is_empty(), "expected no points, got {points:?}"); + + Ok(()) + } + + #[tokio::test] + async fn current_time_greater_than_end_time() -> Result<()> { + let current = 2 * TEST_SAMPLING_INTERVAL.as_nanos() as i64; + let end = TEST_SAMPLING_INTERVAL.as_nanos() as i64; + + let mut agent = + Agent::test_instance(Some(TEST_SAMPLING_INTERVAL), false, current, end); + + let points = agent.generate().await?.into_iter().flatten(); + assert_eq!(points.count(), 2); + + let points = agent.generate().await?.into_iter().flatten(); + let 
points: Vec<_> = points.collect(); + assert!(points.is_empty(), "expected no points, got {points:?}"); + + Ok(()) + } + } + + #[rustfmt::skip] + // # Summary: After generating historical data, continue sampling in "real time" + // + // If there is a sampling interval and we are continuing, generate points as fast as + // possible (but with timestamps separated by sampling_interval amounts) until we catch up + // to `now`. Then add pauses of the sampling_interval's duration, generating points with + // their timestamps set to the current time to simulate "real" point generation. + // + // | sampling_interval | continue | cmp(current_time, end_time) | expected outcome | + // |-------------------+----------+-----------------------------+------------------| + // | Some(_) | true | Less | gen, no delay | + // | Some(_) | true | Equal | gen, delay | + // | Some(_) | true | Greater | gen, delay | + + mod with_continue { + use super::*; + + #[tokio::test] + async fn current_time_less_than_end_time() -> Result<()> { + let end = now_ns(); + let current = end - TEST_SAMPLING_INTERVAL.as_nanos() as i64; + + let mut agent = + Agent::test_instance(Some(TEST_SAMPLING_INTERVAL), true, current, end); + + let points = agent.generate().await?.into_iter().flatten(); + let points: Vec<_> = points.collect(); + assert_eq!(points.len(), 2); + + let times = timestamps(&points).unwrap(); + assert_eq!(vec![current, current], times); + + let points = agent.generate().await?.into_iter().flatten(); + let points: Vec<_> = points.collect(); + assert_eq!(points.len(), 2); + + let times = timestamps(&points).unwrap(); + assert_eq!(vec![end, end], times); + + Ok(()) + } + + #[tokio::test] + async fn current_time_equal_end_time() -> Result<()> { + let end = now_ns(); + let current = end; + + let mut agent = + Agent::test_instance(Some(TEST_SAMPLING_INTERVAL), true, current, end); + + let points = agent.generate().await?.into_iter().flatten(); + let points: Vec<_> = points.collect(); + 
assert_eq!(points.len(), 2); + + let times = timestamps(&points).unwrap(); + assert_eq!(vec![end, end], times); + + let points = agent.generate().await?.into_iter().flatten(); + let points: Vec<_> = points.collect(); + assert_eq!(points.len(), 2); + + let real_now = now_ns(); + + let times = timestamps(&points).unwrap(); + for time in times { + assert!( + time <= real_now, + "expected timestamp {} to be generated before now ({}); \ + was {} nanoseconds greater", + time, + real_now, + time - real_now + ); + } + + Ok(()) + } + + #[tokio::test] + async fn current_time_greater_than_end_time() -> Result<()> { + let end = now_ns(); + let current = end + TEST_SAMPLING_INTERVAL.as_nanos() as i64; + + let mut agent = + Agent::test_instance(Some(TEST_SAMPLING_INTERVAL), true, current, end); + + let points = agent.generate().await?.into_iter().flatten(); + let points: Vec<_> = points.collect(); + assert_eq!(points.len(), 2); + + let times = timestamps(&points).unwrap(); + assert_eq!(vec![current, current], times); + + let points = agent.generate().await?.into_iter().flatten(); + let points: Vec<_> = points.collect(); + assert_eq!(points.len(), 2); + + let real_now = now_ns(); + + let times = timestamps(&points).unwrap(); + for time in times { + assert!( + time <= real_now, + "expected timestamp {} to be generated before now ({}); \ + was {} nanoseconds greater", + time, + real_now, + time - real_now + ); + } + + Ok(()) + } + } + } +} diff --git a/iox_data_generator/src/bin/iox_data_generator.rs b/iox_data_generator/src/bin/iox_data_generator.rs new file mode 100644 index 0000000..3355b28 --- /dev/null +++ b/iox_data_generator/src/bin/iox_data_generator.rs @@ -0,0 +1,268 @@ +//! Entry point for generator CLI. 
+#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)] +#![warn( + missing_copy_implementations, + missing_debug_implementations, + missing_docs, + clippy::explicit_iter_loop, + // See https://github.com/influxdata/influxdb_iox/pull/1671 + clippy::future_not_send, + clippy::use_self, + clippy::clone_on_ref_ptr +)] + +use chrono::prelude::*; +use iox_data_generator::{specification::DataSpec, write::PointsWriterBuilder}; +use std::{ + fs::File, + io::{self, BufRead}, +}; +use tracing::info; + +#[derive(clap::Parser)] +#[clap( + name = "iox_data_generator", + about = "IOx data point generator", + long_about = r#"IOx data point generator + +Examples: + # Generate data points using the specification in `spec.toml` and save in the `lp` directory + iox_data_generator -s spec.toml -o lp + + # Generate data points and write to the server running at localhost:8080 with the provided org, + # bucket and authorization token + iox_data_generator -s spec.toml -h localhost:8080 --org myorg --bucket mybucket --token mytoken + + # Generate data points for the 24 hours between midnight 2020-01-01 and 2020-01-02 + iox_data_generator -s spec.toml -o lp --start 2020-01-01 --end 2020-01-02 + + # Generate data points starting from an hour ago until now, generating the historical data as + # fast as possible. Then generate data according to the sampling interval until terminated. + iox_data_generator -s spec.toml -o lp --start "1 hr" --continue + +Logging: + Use the RUST_LOG environment variable to configure the desired logging level. 
+ For example: + + # Enable INFO level logging for all of iox_data_generator + RUST_LOG=iox_data_generator=info iox_data_generator -s spec.toml -o lp +"#, + author, + version, + disable_help_flag = true, + arg( + clap::Arg::new("help") + .long("help") + .help("Print help information") + .action(clap::ArgAction::Help) + .global(true) + ), +)] +struct Config { + /// Path to the specification TOML file describing the data generation + #[clap(long, short, action)] + specification: String, + + /// Print the generated line protocol from a single sample collection to the terminal + #[clap(long, action)] + print: bool, + + /// Runs the generation with agents writing to a sink. Useful for quick stress test to see how + /// much resources the generator will take + #[clap(long, action)] + noop: bool, + + /// The directory to write line protocol to + #[clap(long, short, action)] + output: Option<String>, + + /// The directory to write Parquet files to + #[clap(long, short, action)] + parquet: Option<String>, + + /// The host name part of the API endpoint to write to + #[clap(long, short, action)] + host: Option<String>, + + /// The organization name to write to + #[clap(long, action)] + org: Option<String>, + + /// The bucket name to write to + #[clap(long, action)] + bucket: Option<String>, + + /// File name with a list of databases. 1 per line in `org_bucket` format + #[clap(long, action)] + database_list: Option<String>, + + /// The API authorization token used for all requests + #[clap(long, action)] + token: Option<String>, + + /// The date and time at which to start the timestamps of the generated data. + /// + /// Can be an exact datetime like `2020-01-01T01:23:45-05:00` or a fuzzy + /// specification like `1 hour`. If not specified, defaults to now. + #[clap(long, action)] + start: Option<String>, + + /// The date and time at which to stop the timestamps of the generated data. + /// + /// Can be an exact datetime like `2020-01-01T01:23:45-05:00` or a fuzzy + /// specification like `1 hour`. If not specified, defaults to now. 
+ #[clap(long, action)] + end: Option<String>, + + /// Generate live data using the intervals from the spec after generating historical data. + /// + /// This option has no effect if you specify an end time. + #[clap(long = "continue", action)] + do_continue: bool, + + /// Generate this many samplings to batch into a single API call. Good for sending a bunch of + /// historical data in quickly if paired with a start time from long ago. + #[clap(long, action, default_value = "1")] + batch_size: usize, + + /// Generate jaeger debug header with given key during write + #[clap(long, action)] + jaeger_debug_header: Option<String>, +} + +#[tokio::main] +async fn main() -> Result<(), Box<dyn std::error::Error>> { + let config: Config = clap::Parser::parse(); + + if !config.print { + tracing_subscriber::fmt::init(); + } + + let execution_start_time = Local::now(); + let execution_start_time_nanos = execution_start_time + .timestamp_nanos_opt() + .expect("'now' is in nano range"); + + let start_datetime = datetime_nanoseconds(config.start.as_deref(), execution_start_time); + let end_datetime = datetime_nanoseconds(config.end.as_deref(), execution_start_time); + + let start_display = start_datetime.unwrap_or(execution_start_time_nanos); + let end_display = end_datetime.unwrap_or(execution_start_time_nanos); + + let continue_on = config.do_continue; + + info!( + "Starting at {}, ending at {} ({}){}", + start_display, + end_display, + (end_display - start_display) / 1_000_000_000, + if continue_on { " then continuing" } else { "" }, + ); + + let data_spec = DataSpec::from_file(&config.specification)?; + + let mut points_writer_builder = if let Some(line_protocol_filename) = config.output { + PointsWriterBuilder::new_file(line_protocol_filename)? + } else if let Some(parquet_directory) = config.parquet { + PointsWriterBuilder::new_parquet(parquet_directory)? 
+ } else if let Some(ref host) = config.host { + let token = config.token.expect("--token must be specified"); + + PointsWriterBuilder::new_api(host, token, config.jaeger_debug_header.as_deref()).await? + } else if config.print { + PointsWriterBuilder::new_std_out() + } else if config.noop { + PointsWriterBuilder::new_no_op(true) + } else { + panic!("One of --print or --output or --host must be provided."); + }; + + let buckets = if config.host.is_some() { + // Buckets are only relevant if we're writing to the API + match (config.org, config.bucket, config.database_list) { + (Some(org), Some(bucket), None) => { + vec![format!("{org}_{bucket}")] + } + (None, None, Some(bucket_list)) => { + let f = File::open(bucket_list).expect("unable to open database_list file"); + + io::BufReader::new(f) + .lines() + .map(|l| l.expect("unable to read database from database_list file")) + .collect::<Vec<_>>() + } + _ => panic!("must specify either --org AND --bucket OR --database_list"), + } + } else { + // But we need at least one database or nothing will be written anywhere + vec![String::from("org_bucket")] + }; + + let result = iox_data_generator::generate( + &data_spec, + buckets, + &mut points_writer_builder, + start_datetime, + end_datetime, + execution_start_time_nanos, + continue_on, + config.batch_size, + config.print, + ) + .await; + + match result { + Ok(total_points) => { + if !config.print { + eprintln!("Submitted {total_points} total points"); + } + } + Err(e) => eprintln!("Execution failed: \n{e}"), + } + + Ok(()) +} + +fn datetime_nanoseconds(arg: Option<&str>, now: DateTime<Local>) -> Option<i64> { + arg.map(|s| { + let datetime = humantime::parse_rfc3339(s) + .map(Into::into) + .unwrap_or_else(|_| { + let std_duration = humantime::parse_duration(s).expect("Could not parse time"); + let chrono_duration = chrono::Duration::from_std(std_duration) + .expect("Could not convert std::time::Duration to chrono::Duration"); + now - chrono_duration + }); + + datetime + .timestamp_nanos_opt() 
+ .expect("timestamp out of range") + }) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn none_datetime_is_none_nanoseconds() { + let ns = datetime_nanoseconds(None, Local::now()); + assert!(ns.is_none()); + } + + #[test] + fn rfc3339() { + let ns = datetime_nanoseconds(Some("2020-01-01T01:23:45Z"), Local::now()); + assert_eq!(ns, Some(1_577_841_825_000_000_000)); + } + + #[test] + fn relative() { + let fixed_now = Local::now(); + let ns = datetime_nanoseconds(Some("1hr"), fixed_now); + let expected = (fixed_now - chrono::Duration::hours(1)) + .timestamp_nanos_opt() + .unwrap(); + assert_eq!(ns, Some(expected)); + } +} diff --git a/iox_data_generator/src/field.rs b/iox_data_generator/src/field.rs new file mode 100644 index 0000000..a32a9d8 --- /dev/null +++ b/iox_data_generator/src/field.rs @@ -0,0 +1,546 @@ +//! Generating a set of field keys and values given a specification + +use crate::{ + now_ns, specification, + substitution::{self, pick_from_replacements}, +}; + +use handlebars::Handlebars; +use rand::rngs::SmallRng; +use rand::Rng; +use rand::SeedableRng; +use serde_json::json; +use serde_json::Value; +use snafu::{ResultExt, Snafu}; +use std::{ops::Range, time::Duration}; + +/// Field-specific Results +pub type Result = std::result::Result; + +/// Errors that may happen while creating fields +#[derive(Snafu, Debug)] +#[allow(missing_docs)] +pub enum Error { + #[snafu(display("Could not create field name, caused by:\n{}", source))] + CouldNotCreateFieldName { source: crate::substitution::Error }, + + #[snafu(display("Could not compile string field template: {}", source))] + CouldNotCompileStringTemplate { + #[snafu(source(from(handlebars::TemplateError, Box::new)))] + source: Box, + }, + + #[snafu(display("Could not render string field template: {}", source))] + CouldNotRenderStringTemplate { + #[snafu(source(from(handlebars::RenderError, Box::new)))] + source: Box, + }, +} + +/// Different field type generators +#[derive(Debug)] +pub enum 
FieldGeneratorImpl { + /// Boolean field generator + Bool(BooleanFieldGenerator), + /// Integer field generator + I64(I64FieldGenerator), + /// Float field generator + F64(F64FieldGenerator), + /// String field generator + String(Box), + /// Uptime field generator + Uptime(UptimeFieldGenerator), +} + +impl FieldGeneratorImpl { + /// Create fields that will generate according to the spec + pub fn from_spec( + spec: &specification::FieldSpec, + data: Value, + execution_start_time: i64, + ) -> Result> { + use specification::FieldValueSpec::*; + + let field_count = spec.count.unwrap_or(1); + + let mut fields = Vec::with_capacity(field_count); + + for field_id in 1..field_count + 1 { + let mut data = data.clone(); + let d = data.as_object_mut().expect("data must be object"); + d.insert("field".to_string(), json!({ "id": field_id })); + + let field_name = substitution::render_once("field", &spec.name, &data) + .context(CouldNotCreateFieldNameSnafu)?; + + let rng = + SmallRng::from_rng(&mut rand::thread_rng()).expect("SmallRng should always create"); + + let field = match &spec.field_value_spec { + Bool(true) => Self::Bool(BooleanFieldGenerator::new(&field_name, rng)), + Bool(false) => unimplemented!("Not sure what false means for bool fields yet"), + I64 { + range, + increment, + reset_after, + } => Self::I64(I64FieldGenerator::new( + &field_name, + range, + *increment, + *reset_after, + rng, + )), + F64 { range } => Self::F64(F64FieldGenerator::new(&field_name, range, rng)), + String { + pattern, + replacements, + } => Self::String(Box::new(StringFieldGenerator::new( + &field_name, + pattern, + data, + replacements.to_vec(), + rng, + )?)), + Uptime { kind } => Self::Uptime(UptimeFieldGenerator::new( + &field_name, + kind, + execution_start_time, + )), + }; + + fields.push(field); + } + + Ok(fields) + } + + /// Writes the field in line protocol to the passed writer + pub fn write_to(&mut self, mut w: W, timestamp: i64) -> std::io::Result<()> { + match self { + 
Self::Bool(f) => { + let v: bool = f.rng.gen(); + write!(w, "{}={}", f.name, v) + } + Self::I64(f) => { + let v = f.generate_value(); + write!(w, "{}={}", f.name, v) + } + Self::F64(f) => { + let v = f.generate_value(); + write!(w, "{}={}", f.name, v) + } + Self::String(f) => { + let v = f.generate_value(timestamp); + write!(w, "{}=\"{}\"", f.name, v) + } + Self::Uptime(f) => match f.kind { + specification::UptimeKind::I64 => { + let v = f.generate_value(); + write!(w, "{}={}", f.name, v) + } + specification::UptimeKind::Telegraf => { + let v = f.generate_value_as_string(); + write!(w, "{}=\"{}\"", f.name, v) + } + }, + } + } +} + +/// Generate boolean field names and values. +#[derive(Debug)] +pub struct BooleanFieldGenerator { + /// The name (key) of the field + pub name: String, + rng: SmallRng, +} + +impl BooleanFieldGenerator { + /// Create a new boolean field generator that will always use the specified + /// name. + pub fn new(name: &str, rng: SmallRng) -> Self { + let name = name.into(); + + Self { name, rng } + } + + /// Generate a random value + pub fn generate_value(&mut self) -> bool { + self.rng.gen() + } +} + +/// Generate integer field names and values. +#[derive(Debug)] +pub struct I64FieldGenerator { + /// The name (key) of the field + pub name: String, + range: Range, + increment: bool, + rng: SmallRng, + previous_value: i64, + reset_after: Option, + current_tick: usize, +} + +impl I64FieldGenerator { + /// Create a new integer field generator that will always use the specified + /// name. 
+ pub fn new( + name: impl Into, + range: &Range, + increment: bool, + reset_after: Option, + rng: SmallRng, + ) -> Self { + Self { + name: name.into(), + range: range.to_owned(), + increment, + rng, + previous_value: 0, + reset_after, + current_tick: 0, + } + } + + /// Generate a random value + pub fn generate_value(&mut self) -> i64 { + let mut value = if self.range.start == self.range.end { + self.range.start + } else { + self.rng.gen_range(self.range.clone()) + }; + + if self.increment { + self.previous_value = self.previous_value.wrapping_add(value); + value = self.previous_value; + + if let Some(reset) = self.reset_after { + self.current_tick += 1; + if self.current_tick >= reset { + self.previous_value = 0; + self.current_tick = 0; + } + } + } + + value + } +} + +/// Generate floating point field names and values. +#[derive(Debug)] +pub struct F64FieldGenerator { + /// The name (key) of the field + pub name: String, + range: Range, + rng: SmallRng, +} + +impl F64FieldGenerator { + /// Create a new floating point field generator that will always use the + /// specified name. + pub fn new(name: impl Into, range: &Range, rng: SmallRng) -> Self { + Self { + name: name.into(), + range: range.to_owned(), + rng, + } + } + + /// Generate a random value + pub fn generate_value(&mut self) -> f64 { + if (self.range.start - self.range.end).abs() < f64::EPSILON { + self.range.start + } else { + self.rng.gen_range(self.range.clone()) + } + } +} + +/// Generate string field names and values. 
+#[derive(Debug)] +pub struct StringFieldGenerator { + /// The name (key) of the field + pub name: String, + rng: SmallRng, + replacements: Vec, + handlebars: Handlebars<'static>, + data: Value, +} + +impl StringFieldGenerator { + /// Create a new string field generator + pub fn new( + name: impl Into, + template: impl Into, + data: Value, + replacements: Vec, + rng: SmallRng, + ) -> Result { + let name = name.into(); + let mut registry = substitution::new_handlebars_registry(); + registry + .register_template_string(&name, template.into()) + .context(CouldNotCompileStringTemplateSnafu)?; + + Ok(Self { + name, + rng, + replacements, + handlebars: registry, + data, + }) + } + + /// Generate a random value + pub fn generate_value(&mut self, timestamp: i64) -> String { + let replacements = pick_from_replacements(&mut self.rng, &self.replacements); + let d = self.data.as_object_mut().expect("data must be object"); + + if replacements.is_empty() { + d.remove("replacements"); + } else { + d.insert("replacements".to_string(), json!(replacements)); + } + + d.insert("timestamp".to_string(), json!(timestamp)); + + self.handlebars + .render(&self.name, &self.data) + .expect("Unable to substitute string field value") + } +} + +/// Generate an i64 field that has the name `uptime` and the value of the number +/// of seconds since the data generator started running +#[derive(Debug)] +pub struct UptimeFieldGenerator { + /// The name (key) of the field + pub name: String, + execution_start_time: i64, + /// The specification type of the uptime field. 
Either an int64 or a string + pub kind: specification::UptimeKind, +} + +impl UptimeFieldGenerator { + fn new( + name: impl Into, + kind: &specification::UptimeKind, + execution_start_time: i64, + ) -> Self { + Self { + name: name.into(), + kind: *kind, + execution_start_time, + } + } + + /// Generates the uptime as an i64 + pub fn generate_value(&mut self) -> i64 { + let elapsed = Duration::from_nanos((now_ns() - self.execution_start_time) as u64); + elapsed.as_secs() as i64 + } + + /// Generates the uptime as a string, which is what should be used if `self.kind == specification::UptimeKind::Telegraf` + pub fn generate_value_as_string(&mut self) -> String { + let elapsed_seconds = self.generate_value(); + let days = elapsed_seconds / (60 * 60 * 24); + let days_plural = if days == 1 { "" } else { "s" }; + + let mut minutes = elapsed_seconds / 60; + let mut hours = minutes / 60; + hours %= 24; + minutes %= 60; + + format!("{days} day{days_plural}, {hours:02}:{minutes:02}") + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::specification::UptimeKind; + use rand::SeedableRng; + use test_helpers::approximately_equal; + + #[test] + fn generate_i64_field_always_the_same() { + // If the specification has the same number for the start and end of the + // range... + let mut i64fg = + I64FieldGenerator::new("i64fg", &(3..3), false, None, SmallRng::from_entropy()); + + let i64_fields: Vec<_> = (0..10).map(|_| i64fg.generate_value()).collect(); + let expected = i64_fields[0]; + + // All the values generated will always be the same. + assert!(i64_fields.iter().all(|f| *f == expected), "{i64_fields:?}"); + + // If the specification has n for the start and n+1 for the end of the range... 
+ let mut i64fg = + I64FieldGenerator::new("i64fg", &(4..5), false, None, SmallRng::from_entropy()); + + let i64_fields: Vec<_> = (0..10).map(|_| i64fg.generate_value()).collect(); + // We know what the value will be even though we're using a real random number generator + let expected = 4; + + // All the values generated will also always be the same, because the end of the + // range is exclusive. + assert!(i64_fields.iter().all(|f| *f == expected), "{i64_fields:?}"); + } + + #[test] + fn generate_i64_field_within_a_range() { + let range = 3..1000; + + let mut i64fg = + I64FieldGenerator::new("i64fg", &range, false, None, SmallRng::from_entropy()); + + let val = i64fg.generate_value(); + + assert!(range.contains(&val), "`{val}` was not in the range"); + } + + #[test] + fn generate_incrementing_i64_field() { + let mut i64fg = + I64FieldGenerator::new("i64fg", &(3..10), true, None, SmallRng::from_entropy()); + + let val1 = i64fg.generate_value(); + let val2 = i64fg.generate_value(); + let val3 = i64fg.generate_value(); + let val4 = i64fg.generate_value(); + + assert!(val1 < val2, "`{val1}` < `{val2}` was false"); + assert!(val2 < val3, "`{val2}` < `{val3}` was false"); + assert!(val3 < val4, "`{val3}` < `{val4}` was false"); + } + + #[test] + fn incrementing_i64_wraps() { + let rng = SmallRng::from_entropy(); + let range = 3..10; + let previous_value = i64::MAX; + + // Construct by hand to set the previous value at the end of i64's range + let mut i64fg = I64FieldGenerator { + name: "i64fg".into(), + range: range.clone(), + increment: true, + reset_after: None, + rng, + previous_value, + current_tick: 0, + }; + + let resulting_range = + range.start.wrapping_add(previous_value)..range.end.wrapping_add(previous_value); + + let val = i64fg.generate_value(); + + assert!( + resulting_range.contains(&val), + "`{val}` was not in the range" + ); + } + + #[test] + fn incrementing_i64_that_resets() { + let reset_after = Some(3); + let mut i64fg = I64FieldGenerator::new( + 
"i64fg", + &(3..8), + true, + reset_after, + SmallRng::from_entropy(), + ); + + let val1 = i64fg.generate_value(); + let val2 = i64fg.generate_value(); + let val3 = i64fg.generate_value(); + let val4 = i64fg.generate_value(); + + assert!(val1 < val2, "`{val1}` < `{val2}` was false"); + assert!(val2 < val3, "`{val2}` < `{val3}` was false"); + assert!(val4 < val3, "`{val4}` < `{val3}` was false"); + } + + #[test] + fn generate_f64_field_always_the_same() { + // If the specification has the same number for the start and end of the + // range... + let start_and_end = 3.0; + let range = start_and_end..start_and_end; + let mut f64fg = F64FieldGenerator::new("f64fg", &range, SmallRng::from_entropy()); + + let f64_fields: Vec<_> = (0..10).map(|_| f64fg.generate_value()).collect(); + + // All the values generated will always be the same known value. + assert!( + f64_fields + .iter() + .all(|f| approximately_equal(*f, start_and_end)), + "{f64_fields:?}" + ); + } + + #[test] + fn generate_f64_field_within_a_range() { + let range = 3.0..1000.0; + let mut f64fg = F64FieldGenerator::new("f64fg", &range, SmallRng::from_entropy()); + + let val = f64fg.generate_value(); + assert!(range.contains(&val), "`{val}` was not in the range"); + } + + #[test] + fn generate_string_field_with_data() { + let fake_now = 1633595510000000000; + + let mut stringfg = StringFieldGenerator::new( + "str", + r#"my value {{measurement.name}} {{format-time "%Y-%m-%d"}}"#, + json!({"measurement": {"name": "foo"}}), + vec![], + SmallRng::from_entropy(), + ) + .unwrap(); + + assert_eq!("my value foo 2021-10-07", stringfg.generate_value(fake_now)); + } + + #[test] + fn uptime_i64() { + // Pretend data generator started running 10 seconds ago + let seconds_ago = 10; + let execution_start_time = now_ns() - seconds_ago * 1_000_000_000; + let mut uptimefg = UptimeFieldGenerator::new("foo", &UptimeKind::I64, execution_start_time); + + assert_eq!(seconds_ago, uptimefg.generate_value()); + } + + #[test] + fn 
uptime_telegraf() { + // Pretend data generator started running 10 days, 2 hours, and 33 minutes ago + let seconds_ago = 10 * 24 * 60 * 60 + 2 * 60 * 60 + 33 * 60; + let execution_start_time = now_ns() - seconds_ago * 1_000_000_000; + let mut uptimefg = UptimeFieldGenerator::new("foo", &UptimeKind::I64, execution_start_time); + + assert_eq!("10 days, 02:33", uptimefg.generate_value_as_string()); + + // Pretend data generator started running 1 day, 14 hours, and 5 minutes ago + // to exercise different formatting + let seconds_in_1_day = 24 * 60 * 60; + let seconds_in_14_hours = 14 * 60 * 60; + let seconds_in_5_minutes = 5 * 60; + + let seconds_ago = seconds_in_1_day + seconds_in_14_hours + seconds_in_5_minutes; + let execution_start_time = now_ns() - seconds_ago * 1_000_000_000; + + let mut uptimefg = UptimeFieldGenerator::new("foo", &UptimeKind::I64, execution_start_time); + + assert_eq!("1 day, 14:05", uptimefg.generate_value_as_string()); + } +} diff --git a/iox_data_generator/src/lib.rs b/iox_data_generator/src/lib.rs new file mode 100644 index 0000000..a496e61 --- /dev/null +++ b/iox_data_generator/src/lib.rs @@ -0,0 +1,343 @@ +//! This crate contains structures and generators for specifying how to generate +//! historical and real-time test data for Delorean. The rules for how to +//! generate data and what shape it should take can be specified in a TOML file. +//! +//! Generators can output in line protocol, Parquet, or can be used to generate +//! real-time load on a server that implements the [InfluxDB 2.0 write +//! path][write-api]. +//! +//! [write-api]: https://v2.docs.influxdata.com/v2.0/api/#tag/Write +//! +//! While this generator could be compared to [the Go based one that creates TSM +//! data][go-gen], its purpose is meant to be more far reaching. In addition to +//! generating historical data, it should be useful for generating data in a +//! sequence as you would expect it to arrive in a production environment. That +//! 
means many agents sending data with their different tags and timestamps. +//! +//! [go-gen]: https://github.com/influxdata/influxdb/pull/12710 + +#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)] +#![warn( + missing_copy_implementations, + missing_debug_implementations, + missing_docs, + clippy::explicit_iter_loop, + // See https://github.com/influxdata/influxdb_iox/pull/1671 + clippy::future_not_send, + clippy::use_self, + clippy::clone_on_ref_ptr, + clippy::todo, + clippy::dbg_macro, + unused_crate_dependencies +)] + +// Workaround for "unused crate" lint false positives. +use clap as _; +#[cfg(test)] +use criterion as _; +use tracing_subscriber as _; + +use crate::{ + agent::{Agent, AgentGenerateStats}, + tag_set::GeneratedTagSets, +}; +use snafu::{ResultExt, Snafu}; +use std::{ + convert::TryFrom, + sync::{atomic::AtomicU64, Arc}, + time::{SystemTime, UNIX_EPOCH}, +}; + +pub mod agent; +pub mod field; +pub mod measurement; +pub mod specification; +pub mod substitution; +mod tag_pair; +pub mod tag_set; +pub mod write; + +/// Errors that may happen while generating points. 
+#[derive(Snafu, Debug)] +pub enum Error { + /// Error that may happen when waiting on a tokio task + #[snafu(display("Could not join tokio task: {}", source))] + TokioError { + /// Underlying tokio error that caused this problem + source: tokio::task::JoinError, + }, + + /// Error that may happen when constructing an agent name + #[snafu(display("Could not create agent name, caused by:\n{}", source))] + CouldNotCreateAgentName { + /// Underlying `substitution` module error that caused this problem + source: substitution::Error, + }, + + /// Error that may happen when an agent generates points + #[snafu(display("Agent could not generate points, caused by:\n{}", source))] + AgentCouldNotGeneratePoints { + /// Underlying `agent` module error that caused this problem + source: agent::Error, + }, + + /// Error that may happen when creating agents + #[snafu(display("Could not create agents, caused by:\n{}", source))] + CouldNotCreateAgent { + /// Underlying `agent` module error that caused this problem + source: agent::Error, + }, + + /// Error that may happen when constructing an agent's writer + #[snafu(display("Could not create writer for agent, caused by:\n{}", source))] + CouldNotCreateAgentWriter { + /// Underlying `write` module error that caused this problem + source: write::Error, + }, + + /// Error generating tags sets + #[snafu(display("Error generating tag sets prior to creating agents: \n{}", source))] + CouldNotGenerateTagSets { + /// Underlying `tag_set` module error + source: tag_set::Error, + }, + + /// Error splitting input buckets to agents that write to them + #[snafu(display( + "Error splitting input buckets into agents that write to them: {}", + source + ))] + CouldNotAssignAgents { + /// Underlying `specification` module error + source: specification::Error, + }, +} + +type Result = std::result::Result; + +/// Generate data from the configuration in the spec. +/// +/// Provide a writer that the line protocol should be written to. 
+/// +/// If `start_datetime` or `end_datetime` are `None`, the current datetime will +/// be used. +#[allow(clippy::too_many_arguments)] +pub async fn generate( + spec: &specification::DataSpec, + databases: Vec, + points_writer_builder: &mut write::PointsWriterBuilder, + start_datetime: Option, + end_datetime: Option, + execution_start_time: i64, + continue_on: bool, + batch_size: usize, + one_agent_at_a_time: bool, // run one agent after another, if printing to stdout +) -> Result { + let mut handles = vec![]; + + let database_agents = spec + .database_split_to_agents(&databases) + .context(CouldNotAssignAgentsSnafu)?; + + let generated_tag_sets = + GeneratedTagSets::from_spec(spec).context(CouldNotGenerateTagSetsSnafu)?; + + let lock = Arc::new(tokio::sync::Mutex::new(())); + + let start = std::time::Instant::now(); + let total_rows = Arc::new(AtomicU64::new(0)); + let total_requests = Arc::new(AtomicU64::new(0)); + + for database_assignments in &database_agents { + let (org, bucket) = org_and_bucket_from_database(database_assignments.database); + + for agent_assignment in database_assignments.agent_assignments.iter() { + let agents = Agent::from_spec( + agent_assignment.spec, + agent_assignment.count, + agent_assignment.sampling_interval, + start_datetime, + end_datetime, + execution_start_time, + continue_on, + &generated_tag_sets, + ) + .context(CouldNotCreateAgentSnafu)?; + + println!( + "Configuring {} agents of \"{}\" to write data \ + to org {} and bucket {} (database {})", + agent_assignment.count, + agent_assignment.spec.name, + org, + bucket, + database_assignments.database, + ); + + let agent_points_writer = Arc::new( + points_writer_builder + .build_for_agent(&agent_assignment.spec.name, org, bucket) + .context(CouldNotCreateAgentWriterSnafu)?, + ); + + for mut agent in agents.into_iter() { + let lock_ref = Arc::clone(&lock); + let agent_points_writer = Arc::clone(&agent_points_writer); + + let total_rows = Arc::clone(&total_rows); + let 
total_requests = Arc::clone(&total_requests); + handles.push(tokio::task::spawn(async move { + // did this weird hack because otherwise the stdout outputs would be jumbled + // together garbage + if one_agent_at_a_time { + let _l = lock_ref.lock().await; + agent + .generate_all( + agent_points_writer, + batch_size, + total_rows, + total_requests, + ) + .await + } else { + agent + .generate_all( + agent_points_writer, + batch_size, + total_rows, + total_requests, + ) + .await + } + })); + } + } + } + + let mut stats = vec![]; + for handle in handles { + stats.push( + handle + .await + .context(TokioSnafu)? + .context(AgentCouldNotGeneratePointsSnafu)?, + ); + } + let stats = stats + .into_iter() + .fold(AgentGenerateStats::default(), |totals, res| { + AgentGenerateStats { + request_count: totals.request_count + res.request_count, + error_count: totals.error_count + res.error_count, + row_count: totals.row_count + res.row_count, + } + }); + + println!("{}", stats.display_stats(start.elapsed())); + + Ok(stats.row_count) +} + +/// Gets the current time in nanoseconds since the epoch +pub fn now_ns() -> i64 { + let since_the_epoch = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards"); + i64::try_from(since_the_epoch.as_nanos()).expect("Time does not fit") +} + +fn org_and_bucket_from_database(database: &str) -> (&str, &str) { + let parts = database.split('_').collect::>(); + if parts.len() != 2 { + panic!("error parsing org and bucket from {database}"); + } + + (parts[0], parts[1]) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::specification::*; + use influxdb2_client::models::WriteDataPoint; + use std::str::FromStr; + use std::time::Duration; + + type Error = Box; + type Result = std::result::Result; + + #[tokio::test] + async fn historical_data_sampling_interval() -> Result<()> { + let toml = r#" +name = "demo_schema" + +[[agents]] +name = "foo" + +[[agents.measurements]] +name = "cpu" + +[[agents.measurements.fields]] +name = 
"val" +i64_range = [1, 1] + +[[database_writers]] +agents = [{name = "foo", sampling_interval = "10s"}] +"#; + let data_spec = DataSpec::from_str(toml).unwrap(); + let agent_spec = &data_spec.agents[0]; + + let execution_start_time = now_ns(); + + // imagine we've specified at the command line that we want to generate metrics + // for 1970 + let start_datetime = Some(0); + // for the first 15 seconds of the year + let end_datetime = Some(15 * 1_000_000_000); + + let generated_tag_sets = GeneratedTagSets::default(); + + let mut agent = agent::Agent::from_spec( + agent_spec, + 1, + Duration::from_secs(10), + start_datetime, + end_datetime, + execution_start_time, + false, + &generated_tag_sets, + )?; + + let data_points = agent[0].generate().await?.into_iter().flatten(); + let mut v = Vec::new(); + for data_point in data_points { + data_point.write_data_point_to(&mut v).unwrap(); + } + let line_protocol = String::from_utf8(v).unwrap(); + + // Get a point for time 0 + let expected_line_protocol = "cpu val=1i 0\n"; + assert_eq!(line_protocol, expected_line_protocol); + + let data_points = agent[0].generate().await?.into_iter().flatten(); + let mut v = Vec::new(); + for data_point in data_points { + data_point.write_data_point_to(&mut v).unwrap(); + } + let line_protocol = String::from_utf8(v).unwrap(); + + // Get a point for time 10s + let expected_line_protocol = "cpu val=1i 10000000000\n"; + assert_eq!(line_protocol, expected_line_protocol); + + // Don't get any points anymore because we're past the ending datetime + let data_points = agent[0].generate().await?.into_iter().flatten(); + let data_points: Vec<_> = data_points.collect(); + assert!( + data_points.is_empty(), + "expected no data points, got {data_points:?}" + ); + + Ok(()) + } +} diff --git a/iox_data_generator/src/measurement.rs b/iox_data_generator/src/measurement.rs new file mode 100644 index 0000000..a354404 --- /dev/null +++ b/iox_data_generator/src/measurement.rs @@ -0,0 +1,661 @@ +//! 
Generating a set of points for one measurement configuration + +#![allow(clippy::result_large_err)] + +use crate::{ + field::FieldGeneratorImpl, + specification, substitution, + tag_pair::TagPair, + tag_set::{GeneratedTagSets, TagSet}, +}; +use influxdb2_client::models::WriteDataPoint; +use serde_json::json; +use snafu::{OptionExt, ResultExt, Snafu}; +use std::{ + fmt::Debug, + sync::{Arc, Mutex}, +}; + +/// Measurement-specific Results +pub type Result = std::result::Result; + +/// Errors that may happen while creating measurements +#[derive(Snafu, Debug)] +#[allow(missing_docs)] +pub enum Error { + #[snafu(display( + "Could not build data point for measurement `{}` with Influx Client, caused by:\n{}", + name, + source + ))] + InfluxDataPointError { + name: String, + source: influxdb2_client::models::data_point::DataPointError, + }, + + #[snafu(display("Could not create measurement name, caused by:\n{}", source))] + CouldNotCreateMeasurementName { source: crate::substitution::Error }, + + #[snafu(display( + "Could not create field generator sets for measurement `{}`, caused by:\n{}", + name, + source + ))] + CouldNotCreateFieldGeneratorSets { + name: String, + source: crate::field::Error, + }, + + #[snafu(display( + "Tag set {} referenced not found for measurement {}", + tag_set, + measurement + ))] + GeneratedTagSetNotFound { + tag_set: String, + measurement: String, + }, + + #[snafu(display("Could not compile template `{}`, caused by:\n{}", template, source))] + CantCompileTemplate { + source: handlebars::TemplateError, + template: String, + }, + + #[snafu(display("Could not render template `{}`, caused by:\n{}", template, source))] + CantRenderTemplate { + source: handlebars::RenderError, + template: String, + }, + + #[snafu(display("Error creating measurement tag pairs: {}", source))] + CouldNotCreateMeasurementTagPairs { source: crate::tag_pair::Error }, +} + +/// Generate measurements +#[derive(Debug)] +pub struct MeasurementGenerator { + measurement: Arc>, 
+} + +impl MeasurementGenerator { + /// Create the count specified number of measurement generators from + /// the passed `MeasurementSpec` + pub fn from_spec( + agent_id: usize, + spec: &specification::MeasurementSpec, + execution_start_time: i64, + generated_tag_sets: &GeneratedTagSets, + agent_tag_pairs: &[Arc], + ) -> Result> { + let count = spec.count.unwrap_or(1) + 1; + + (1..count) + .map(|measurement_id| { + Self::new( + agent_id, + measurement_id, + spec, + execution_start_time, + generated_tag_sets, + agent_tag_pairs, + ) + }) + .collect::>>() + } + + /// Create a new way to generate measurements from a specification + #[allow(clippy::too_many_arguments)] + pub fn new( + agent_id: usize, + measurement_id: usize, + spec: &specification::MeasurementSpec, + execution_start_time: i64, + generated_tag_sets: &GeneratedTagSets, + agent_tag_pairs: &[Arc], + ) -> Result { + let measurement_name = substitution::render_once( + "measurement", + &spec.name, + &json!({ + "agent": {"id": agent_id}, + "measurement": {"id": measurement_id}, + }), + ) + .context(CouldNotCreateMeasurementNameSnafu)?; + + let fields = spec + .fields + .iter() + .map(|field_spec| { + let data = json!({ + "agent": {"id": agent_id}, + "measurement": {"id": measurement_id, "name": &measurement_name}, + }); + + FieldGeneratorImpl::from_spec(field_spec, data, execution_start_time) + }) + .collect::>>() + .context(CouldNotCreateFieldGeneratorSetsSnafu { + name: &measurement_name, + })? 
+ .into_iter() + .flatten() + .collect(); + + // generate the tag pairs + let template_data = json!({ + "agent": {"id": agent_id}, + "measurement": {"id": measurement_id, "name": &measurement_name}, + }); + + let mut tag_pairs = TagPair::pairs_from_specs(&spec.tag_pairs, template_data) + .context(CouldNotCreateMeasurementTagPairsSnafu)?; + for t in agent_tag_pairs { + tag_pairs.push(Arc::clone(t)); + } + + let generated_tag_sets = match &spec.tag_set { + Some(t) => Arc::clone(generated_tag_sets.sets_for(t).context( + GeneratedTagSetNotFoundSnafu { + tag_set: t, + measurement: &measurement_name, + }, + )?), + // if there's no generated tag set, just have an empty set as a single row so + // it can be used to generate the single line that will come out of each generation + // for this measurement. + None => Arc::new(vec![TagSet { tags: vec![] }]), + }; + + // I have this gnarly tag ordering construction so that I can keep the pre-generated + // tag sets in their existing vecs without moving them around so that I can have + // many thousands of agents and measurements that use the same tagset without blowing + // up the number of vectors and memory I consume. 
+ let mut tag_ordering: Vec<_> = tag_pairs + .iter() + .enumerate() + .map(|(i, p)| (p.key(), TagOrdering::Pair(i))) + .chain( + generated_tag_sets[0] + .tags + .iter() + .enumerate() + .map(|(i, p)| (p.key.to_string(), TagOrdering::Generated(i))), + ) + .collect(); + tag_ordering.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + let tag_ordering: Vec<_> = tag_ordering.into_iter().map(|(_, o)| o).collect(); + + Ok(Self { + measurement: Arc::new(Mutex::new(Measurement { + name: measurement_name, + tag_pairs, + generated_tag_sets, + tag_ordering, + fields, + })), + }) + } + + /// Create a line iterator to generate lines for a single sampling + pub fn generate(&mut self, timestamp: i64) -> Result { + Ok(MeasurementLineIterator { + measurement: Arc::clone(&self.measurement), + index: 0, + timestamp, + }) + } +} + +/// Details for the measurement to be generated. Can generate many lines +/// for each sampling. +#[derive(Debug)] +pub struct Measurement { + name: String, + tag_pairs: Vec>, + generated_tag_sets: Arc>, + tag_ordering: Vec, + fields: Vec, +} + +impl Measurement { + /// The number of lines that will be generated for each sampling of this measurement. + pub fn line_count(&self) -> usize { + self.generated_tag_sets.len() + } + + /// Write the specified line as line protocol to the passed in writer. + pub fn write_index_to( + &mut self, + index: usize, + timestamp: i64, + mut w: W, + ) -> std::io::Result<()> { + write!(w, "{}", self.name)?; + let row_tags = &self.generated_tag_sets[index].tags; + for t in &self.tag_ordering { + match t { + TagOrdering::Generated(index) => { + let t = &row_tags[*index]; + write!(w, ",{}={}", t.key, t.value)?; + } + TagOrdering::Pair(index) => { + let t = &self.tag_pairs[*index].as_ref(); + match t { + TagPair::Static(t) => write!(w, ",{}={}", t.key, t.value)?, + TagPair::Regenerating(t) => { + let mut t = t.lock().expect("mutex poisoned"); + let p = t.tag_pair(); + write!(w, ",{}={}", p.key, p.value)? 
+ } + } + } + } + } + + for (i, field) in self.fields.iter_mut().enumerate() { + let d = if i == 0 { b" " } else { b"," }; + w.write_all(d)?; + + match field { + FieldGeneratorImpl::Bool(f) => { + let v = f.generate_value(); + write!(w, "{}={}", f.name, if v { "t" } else { "f" })?; + } + FieldGeneratorImpl::I64(f) => { + let v = f.generate_value(); + write!(w, "{}={}i", f.name, v)?; + } + FieldGeneratorImpl::F64(f) => { + let v = f.generate_value(); + write!(w, "{}={}", f.name, v)?; + } + FieldGeneratorImpl::String(f) => { + let v = f.generate_value(timestamp); + write!(w, "{}=\"{}\"", f.name, v)?; + } + FieldGeneratorImpl::Uptime(f) => match f.kind { + specification::UptimeKind::I64 => { + let v = f.generate_value(); + write!(w, "{}={}i", f.name, v)?; + } + specification::UptimeKind::Telegraf => { + let v = f.generate_value_as_string(); + write!(w, "{}=\"{}\"", f.name, v)?; + } + }, + } + } + + writeln!(w, " {timestamp}") + } +} + +#[derive(Debug)] +enum TagOrdering { + Pair(usize), + Generated(usize), +} + +/// Iterator to generate the lines for a given measurement +#[derive(Debug)] +pub struct MeasurementLineIterator { + measurement: Arc>, + index: usize, + timestamp: i64, +} + +impl MeasurementLineIterator { + /// Number of lines that will be generated for this measurement + pub fn line_count(&self) -> usize { + let m = self.measurement.lock().expect("mutex poinsoned"); + m.line_count() + } +} + +impl Iterator for MeasurementLineIterator { + type Item = LineToGenerate; + + /// Get the details for the next `LineToGenerate` + fn next(&mut self) -> Option { + let m = self.measurement.lock().expect("mutex poinsoned"); + + if self.index >= m.line_count() { + None + } else { + let n = Some(LineToGenerate { + measurement: Arc::clone(&self.measurement), + index: self.index, + timestamp: self.timestamp, + }); + self.index += 1; + n + } + } +} + +/// A pointer to the line to be generated. Will be evaluated when asked to write. 
+#[derive(Debug)] +pub struct LineToGenerate { + /// The measurement state to be used to generate the line + pub measurement: Arc>, + /// The index into the generated tag pairs of the line we're generating + pub index: usize, + /// The timestamp of the line that we're generating + pub timestamp: i64, +} + +impl WriteDataPoint for LineToGenerate { + /// Generate the data and write the line to the passed in writer. + fn write_data_point_to(&self, w: W) -> std::io::Result<()> + where + W: std::io::Write, + { + let mut m = self.measurement.lock().expect("mutex poisoned"); + m.write_index_to(self.index, self.timestamp, w) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::specification::*; + use influxdb2_client::models::WriteDataPoint; + use std::str; + + type Error = Box; + type Result = std::result::Result; + + impl MeasurementGenerator { + fn generate_string(&mut self, timestamp: i64) -> Result { + self.generate_strings(timestamp) + .map(|mut strings| strings.swap_remove(0)) + } + + fn generate_strings(&mut self, timestamp: i64) -> Result> { + let points = self.generate(timestamp)?; + points + .into_iter() + .map(|point| { + let mut v = Vec::new(); + point.write_data_point_to(&mut v)?; + Ok(String::from_utf8(v)?) + }) + .collect() + } + } + + #[test] + fn generate_measurement() -> Result { + let fake_now = 5678; + + // This is the same as the previous test but with an additional field. 
+ let measurement_spec = MeasurementSpec { + name: "cpu".into(), + count: Some(2), + fields: vec![ + FieldSpec { + name: "load".into(), + field_value_spec: FieldValueSpec::F64 { range: 0.0..100.0 }, + count: None, + }, + FieldSpec { + name: "response_time".into(), + field_value_spec: FieldValueSpec::I64 { + range: 0..60_000, + increment: false, + reset_after: None, + }, + count: None, + }, + ], + tag_set: None, + tag_pairs: vec![], + }; + + let generated_tag_sets = GeneratedTagSets::default(); + + let mut measurement_generator = + MeasurementGenerator::new(0, 0, &measurement_spec, fake_now, &generated_tag_sets, &[]) + .unwrap(); + + let line_protocol = vec![measurement_generator.generate_string(fake_now)?]; + let response_times = extract_field_values("response_time", &line_protocol); + + let next_line_protocol = vec![measurement_generator.generate_string(fake_now + 1)?]; + let next_response_times = extract_field_values("response_time", &next_line_protocol); + + // Each line should have a different response time unless we get really, really unlucky + assert_ne!(response_times, next_response_times); + + Ok(()) + } + + #[test] + fn generate_measurement_with_basic_tags() -> Result { + let fake_now = 678; + + let measurement_spec = MeasurementSpec { + name: "measurement".to_string(), + count: None, + tag_set: None, + tag_pairs: vec![ + TagPairSpec { + key: "some_name".to_string(), + template: "some_value".to_string(), + count: None, + regenerate_after_lines: None, + }, + TagPairSpec { + key: "tag_name".to_string(), + template: "tag_value".to_string(), + count: None, + regenerate_after_lines: None, + }, + ], + fields: vec![FieldSpec { + name: "field_name".to_string(), + field_value_spec: FieldValueSpec::I64 { + range: 1..1, + increment: false, + reset_after: None, + }, + count: None, + }], + }; + let generated_tag_sets = GeneratedTagSets::default(); + + let mut measurement_generator = + MeasurementGenerator::new(0, 0, &measurement_spec, fake_now, &generated_tag_sets, 
&[]) + .unwrap(); + + let line_protocol = measurement_generator.generate_string(fake_now)?; + + assert_eq!( + line_protocol, + format!( + "measurement,some_name=some_value,tag_name=tag_value field_name=1i {fake_now}\n" + ) + ); + + Ok(()) + } + + #[test] + fn generate_measurement_with_tags_with_count() { + let fake_now = 678; + + let measurement_spec = MeasurementSpec { + name: "measurement".to_string(), + count: None, + tag_set: None, + tag_pairs: vec![TagPairSpec { + key: "some_name".to_string(), + template: "some_value {{id}}".to_string(), + count: Some(2), + regenerate_after_lines: None, + }], + fields: vec![FieldSpec { + name: "field_name".to_string(), + field_value_spec: FieldValueSpec::I64 { + range: 1..1, + increment: false, + reset_after: None, + }, + count: None, + }], + }; + let generated_tag_sets = GeneratedTagSets::default(); + + let mut measurement_generator = + MeasurementGenerator::new(0, 0, &measurement_spec, fake_now, &generated_tag_sets, &[]) + .unwrap(); + + let line_protocol = measurement_generator.generate_string(fake_now).unwrap(); + + assert_eq!( + line_protocol, + format!( + "measurement,some_name=some_value 1,some_name2=some_value 2 field_name=1i {fake_now}\n" + ) + ); + } + + #[test] + fn regenerating_after_lines() { + let data_spec: specification::DataSpec = toml::from_str( + r#" + name = "ex" + + [[values]] + name = "foo" + template = "{{id}}" + cardinality = 3 + + [[tag_sets]] + name = "foo_set" + for_each = ["foo"] + + [[agents]] + name = "foo" + + [[agents.measurements]] + name = "m1" + tag_set = "foo_set" + tag_pairs = [{key = "reg", template = "data-{{line_number}}", regenerate_after_lines = 2}] + + [[agents.measurements.fields]] + name = "val" + i64_range = [3, 3] + + [[database_writers]] + agents = [{name = "foo", sampling_interval = "10s"}]"#, + ) + .unwrap(); + + let fake_now = 678; + + let generated_tag_sets = GeneratedTagSets::from_spec(&data_spec).unwrap(); + + let mut measurement_generator = MeasurementGenerator::new( + 42, 
+ 1, + &data_spec.agents[0].measurements[0], + fake_now, + &generated_tag_sets, + &[], + ) + .unwrap(); + + let points = measurement_generator.generate(fake_now).unwrap(); + let mut v = Vec::new(); + for point in points { + point.write_data_point_to(&mut v).unwrap(); + } + let line_protocol = str::from_utf8(&v).unwrap(); + + assert_eq!( + line_protocol, + format!( + "m1,foo=1,reg=data-1 val=3i {fake_now}\nm1,foo=2,reg=data-1 val=3i {fake_now}\nm1,foo=3,reg=data-3 val=3i {fake_now}\n" + ) + ); + } + + #[test] + fn tag_set_and_tag_pairs() { + let data_spec: specification::DataSpec = toml::from_str( + r#" + name = "ex" + + [[values]] + name = "foo" + template = "foo-{{id}}" + cardinality = 2 + + [[tag_sets]] + name = "foo_set" + for_each = ["foo"] + + [[agents]] + name = "foo" + + [[agents.measurements]] + name = "m1" + tag_set = "foo_set" + tag_pairs = [{key = "hello", template = "world{{measurement.id}}"}] + + [[agents.measurements.fields]] + name = "val" + i64_range = [3, 3] + + [[database_writers]] + database_ratio = 1.0 + agents = [{name = "foo", sampling_interval = "10s"}]"#, + ) + .unwrap(); + + let fake_now = 678; + + let generated_tag_sets = GeneratedTagSets::from_spec(&data_spec).unwrap(); + + let mut measurement_generator = MeasurementGenerator::new( + 42, + 1, + &data_spec.agents[0].measurements[0], + fake_now, + &generated_tag_sets, + &[], + ) + .unwrap(); + + let points = measurement_generator.generate(fake_now).unwrap(); + let mut v = Vec::new(); + for point in points { + point.write_data_point_to(&mut v).unwrap(); + } + let line_protocol = str::from_utf8(&v).unwrap(); + + assert_eq!( + line_protocol, + format!( + "m1,foo=foo-1,hello=world1 val=3i {fake_now}\nm1,foo=foo-2,hello=world1 val=3i {fake_now}\n" + ) + ); + } + + fn extract_field_values<'a>(field_name: &str, lines: &'a [String]) -> Vec<&'a str> { + lines + .iter() + .map(|line| { + let mut split = line.splitn(2, ' '); + split.next(); + let after_space = split.next().unwrap(); + let prefix = 
format!(",{field_name}="); + let after = after_space.rsplit_once(&prefix).unwrap().1; + after.split_once(',').map_or(after, |x| x.0) + }) + .collect() + } +} diff --git a/iox_data_generator/src/specification.rs b/iox_data_generator/src/specification.rs new file mode 100644 index 0000000..4b30935 --- /dev/null +++ b/iox_data_generator/src/specification.rs @@ -0,0 +1,900 @@ +//! Reading and interpreting data generation specifications. + +use humantime::parse_duration; +use regex::Regex; +use serde::Deserialize; +use snafu::{OptionExt, ResultExt, Snafu}; +use std::{fs, ops::Range, str::FromStr, sync::Arc, time::Duration}; +use tracing::warn; + +/// Errors that may happen while reading a TOML specification. +#[derive(Snafu, Debug)] +#[allow(missing_docs)] +pub enum Error { + /// File-related error that may happen while reading a specification + #[snafu(display( + r#"Error reading data spec from TOML file at {}: {}"#, + file_name, + source + ))] + ReadFile { + file_name: String, + /// Underlying I/O error that caused this problem + source: std::io::Error, + }, + + /// TOML parsing error that may happen while interpreting a specification + #[snafu(display(r#"Error parsing data spec from TOML: {}"#, source))] + Parse { + /// Underlying TOML error that caused this problem + source: toml::de::Error, + }, + + #[snafu(display("Sampling interval must be valid string: {}", source))] + InvalidSamplingInterval { source: humantime::DurationError }, + + #[snafu(display( + "Agent {} referenced in database_writers, but not present in spec", + agent + ))] + AgentNotFound { agent: String }, + + #[snafu(display("database_writers can only use database_ratio or database_regex, not both"))] + DatabaseWritersConfig, + + #[snafu(display( + "database_writer missing database_regex. 
If one uses a regex, all others must also use it" + ))] + RegexMissing, + + #[snafu(display("database_writers regex {} failed with error: {}", regex, source))] + RegexCompile { regex: String, source: regex::Error }, +} + +type Result = std::result::Result; + +/// The full specification for the generation of a data set. +#[derive(Deserialize, Debug)] +#[serde(deny_unknown_fields)] +pub struct DataSpec { + /// This name can be referenced in handlebars templates as `{{spec_name}}` + pub name: String, + /// Specifies values that are generated before agents are created. These values + /// can be used in tag set specs, which will pre-create tag sets that can then be + /// used by the agent specs. + #[serde(default)] + pub values: Vec, + /// Specifies collections of tag sets that can be referenced by agents. These + /// pre-generated tag sets are an efficient way to have many tags without + /// re-rendering their values on every agent generation. They can also have + /// dependent values, making it easy to create high cardinality data sets + /// without running through many handlebar renders while having a well defined + /// set of tags that appear. + #[serde(default)] + pub tag_sets: Vec, + /// The specification for the agents that can be used to write data to databases. + pub agents: Vec, + /// The specification for writing to the provided list of databases. + pub database_writers: Vec, +} + +impl DataSpec { + /// Given a filename, read the file and parse the specification. 
+ pub fn from_file(file_name: &str) -> Result { + let spec_toml = fs::read_to_string(file_name).context(ReadFileSnafu { file_name })?; + Self::from_str(&spec_toml) + } + + /// Given a collection of database names, assign each a set of agents based on the spec + pub fn database_split_to_agents<'a>( + &'a self, + databases: &'a [String], + ) -> Result>> { + let mut database_agents = Vec::with_capacity(databases.len()); + + let mut start = 0; + + // either all database writers must use regex or none of them can. It's either ratio or + // regex for assignment + let use_ratio = self.database_writers[0].database_regex.is_none(); + for b in &self.database_writers { + if use_ratio && b.database_regex.is_some() { + return DatabaseWritersConfigSnafu.fail(); + } + } + + for w in &self.database_writers { + let agents: Vec<_> = w + .agents + .iter() + .map(|a| { + let count = a.count.unwrap_or(1); + let sampling_interval = parse_duration(&a.sampling_interval) + .context(InvalidSamplingIntervalSnafu)?; + let spec = self + .agent_by_name(&a.name) + .context(AgentNotFoundSnafu { agent: &a.name })?; + + Ok(AgentAssignment { + spec, + count, + sampling_interval, + }) + }) + .collect::>>()?; + let agents = Arc::new(agents); + + let selected_databases = if use_ratio { + if start >= databases.len() { + warn!( + "database_writers percentages > 1.0. 
Writer {:?} and later skipped.", + w + ); + break; + } + + let mut end = (databases.len() as f64 * w.database_ratio.unwrap_or(1.0)).ceil() + as usize + + start; + if end > databases.len() { + end = databases.len(); + } + + let selected_databases = databases[start..end].iter().collect::>(); + start = end; + selected_databases + } else { + let p = w.database_regex.as_ref().context(RegexMissingSnafu)?; + let re = Regex::new(p).context(RegexCompileSnafu { regex: p })?; + databases + .iter() + .filter(|name| re.is_match(name)) + .collect::>() + }; + + for database in selected_databases { + database_agents.push(DatabaseAgents { + database, + agent_assignments: Arc::clone(&agents), + }) + } + } + + Ok(database_agents) + } + + /// Get the agent spec by its name + pub fn agent_by_name(&self, name: &str) -> Option<&AgentSpec> { + self.agents.iter().find(|&a| a.name == name) + } +} + +#[derive(Debug)] +/// Assignment info for an agent to a database +pub struct AgentAssignment<'a> { + /// The agent specification for writing to the assigned database + pub spec: &'a AgentSpec, + /// The number of these agents that should be writing to the database + pub count: usize, + /// The sampling interval agents will generate data on + pub sampling_interval: Duration, +} + +#[derive(Debug)] +/// Agent assignments mapped to a database +pub struct DatabaseAgents<'a> { + /// The database data will get written to + pub database: &'a str, + /// The agents specifications that will be writing to the database + pub agent_assignments: Arc>>, +} + +impl FromStr for DataSpec { + type Err = Error; + + fn from_str(spec_toml: &str) -> std::result::Result::Err> { + let spec: Self = toml::from_str(spec_toml).context(ParseSnafu)?; + Ok(spec) + } +} + +/// The specification of values that can be used to generate tag sets +#[derive(Deserialize, Debug, Clone)] +#[cfg_attr(test, derive(Default))] +#[serde(deny_unknown_fields)] +pub struct ValuesSpec { + /// The name of the collection of values + pub name: 
String, + /// If values not specified this handlebars template will be used to create each value in the + /// collection + pub template: String, + /// How many of these values should be generated. If belongs_to is + /// specified, each parent will have this many of this value. So + /// the total number of these values generated would be parent.len() * self.cardinality + pub cardinality: usize, + /// A collection of strings to other values. Each one of these values will have one + /// of the referenced has_one. Further, when generating this, the has_one collection + /// will cycle through so that each successive value will use the next has_one value + /// for association + pub has_one: Option>, + /// A collection of values that each of these values belongs to. These relationships + /// can be referenced in the value generation and in the generation of tag sets. + pub belongs_to: Option, +} + +impl ValuesSpec { + /// returns true if there are other value collections that this values spec must use to + /// be generated + pub fn has_dependent_values(&self) -> bool { + self.has_one.is_some() || self.belongs_to.is_some() + } +} + +/// The specification of tag sets that can be referenced in measurements to pull a pre-generated +/// set of tags in. +#[derive(Deserialize, Debug)] +#[cfg_attr(test, derive(Default))] +#[serde(deny_unknown_fields)] +pub struct TagSetsSpec { + /// The name of the tag set spec + pub name: String, + /// An array of the `ValuesSpec` to loop through. To reference parent belongs_to or has_one + /// values, the parent should come first and then the has_one or child next. Each successive + /// entry in this array is a nested loop. Multiple has_one and a belongs_to on a parent can + /// be traversed. + pub for_each: Vec, +} + +/// The specification for what should be written to the list of provided databases. +/// Databases will be written to by one or more agents with the given sampling interval and +/// agent count. 
+#[derive(Deserialize, Debug)] +#[serde(deny_unknown_fields)] +pub struct DatabaseWriterSpec { + /// The ratio of databases from the provided list that should use these agents. The + /// ratios of the collection of database_writer specs should add up to 1.0. If ratio + /// is not provided it will default to 1.0 (useful for when you specify only a single + /// database_writer. + /// + /// The interval over the provided list of databases is the cumulative sum of the + /// previous ratios to this ratio. So if you have 10 input databases and 3 database_writers + /// with ratios (in order) of `[0.2, 0.4, and 0.6]` you would have the input list of + /// 10 databases split into these three based on their index in the list: `[0, 1]`, + /// `[2, 5]`, and `[6, 9]`. The first 2 databases, then the next 4, then the remaining 6. + /// + /// The list isn't shuffled as ratios are applied. + pub database_ratio: Option, + /// Regex to select databases from the provided list. If regex is used in any one + /// of the database_writers, it must be used for all of them. + pub database_regex: Option, + /// The agents that should be used to write to these databases. + pub agents: Vec, +} + +/// The specification for the specific configuration of how an agent should write to a database. +#[derive(Deserialize, Debug, Clone)] +#[serde(deny_unknown_fields)] +pub struct AgentAssignmentSpec { + /// The name of the `AgentSpec` to use + pub name: String, + /// The number of these agents that should write to the database + pub count: Option, + /// How frequently each agent will write to the database. This is applicable when using the + /// --continue flag. Otherwise, if doing historical backfill, timestamps of generated data + /// will be this far apart and data will be written in as quickly as possible. + pub sampling_interval: String, +} + +/// The specification of the behavior of an agent, the entity responsible for +/// generating a number of data points according to its configuration. 
+#[derive(Deserialize, Debug)] +#[cfg_attr(test, derive(Default))] +#[serde(deny_unknown_fields)] +pub struct AgentSpec { + /// The name of the agent, which can be referenced in templates with `agent.name`. + pub name: String, + /// The specifications for the measurements for the agent to generate. + pub measurements: Vec, + /// A collection of strings that reference other `Values` collections. Each agent will have one + /// of the referenced has_one. Further, when generating this, the has_one collection + /// will cycle through so that each successive agent will use the next has_one value + /// for association + #[serde(default)] + pub has_one: Vec, + /// Specification of tag key/value pairs that get generated once and reused for + /// every sampling. Every measurement (and thus line) will have these tag pairs added onto it. + /// The template can use `{{agent.id}}` to reference the agent's id and `{{guid}}` or + /// `{{random N}}` to generate random strings. + #[serde(default)] + pub tag_pairs: Vec, +} + +/// The specification of how to generate data points for a particular +/// measurement. +#[derive(Deserialize, Debug)] +#[cfg_attr(test, derive(Default))] +#[serde(deny_unknown_fields)] +pub struct MeasurementSpec { + /// Name of the measurement. Can be a plain string or a string with + /// placeholders for: + /// + /// - `{{agent.id}}` - the agent ID + /// - `{{measurement.id}}` - the measurement's ID, which must be used if + /// `count` > 1 so that unique measurement names are created + pub name: String, + /// The number of measurements with this configuration that should be + /// created. Default value is 1. If specified, use `{{id}}` + /// in this measurement's `name` to create unique measurements. + pub count: Option, + /// Specifies a tag set to include in every sampling in addition to tags specified + pub tag_set: Option, + /// Specification of tag key/value pairs that get generated once and reused for + /// every sampling. 
+ #[serde(default)] + pub tag_pairs: Vec, + /// Specification of the fields for this measurement. At least one field is + /// required. + pub fields: Vec, +} + +/// Specification of a tag key/value pair whose template will be evaluated once and +/// the value will be reused across every sampling. +#[derive(Deserialize, Debug, Clone)] +#[cfg_attr(test, derive(Default))] +#[serde(deny_unknown_fields)] +pub struct TagPairSpec { + /// The tag key. If `count` is specified, the id of the tag will be automatically + /// appended to the end of the key to ensure it is unique. + pub key: String, + /// The template to generate the tag value + pub template: String, + /// If specified, this number of tags will be generated with this template. Each will + /// have a key of `key#` where # is the number. Useful for creating a degenerate case + /// of having dozens or hundreds of tags + pub count: Option, + /// If specified, the tag template will be re-evaluated after this many lines have been + /// generated. This will go across samplings. For example, if you have this set to 3 and + /// each sample generates two lines, it will get regenerated after the first line in the + /// second sample. This is useful for simulating things like tracing use cases or ephemeral + /// identifiers like process or container IDs. The template has access to the normal data + /// accessible as well as `line_number`. + pub regenerate_after_lines: Option, +} + +/// The specification of how to generate field keys and values for a particular +/// measurement. +#[derive(Deserialize, Debug)] +#[cfg_attr(test, derive(Default))] +#[serde(from = "FieldSpecIntermediate")] +pub struct FieldSpec { + /// Key/name for this field. 
Can be a plain string or a string with + /// placeholders for: + /// + /// - `{{agent.id}}` - the agent ID + /// - `{{measurement.id}}` - the measurement ID + /// - `{{field.id}}` - the field ID, which must be used if `count` > 1 so + /// that unique field names are created + pub name: String, + /// Specification for the value for this field. + pub field_value_spec: FieldValueSpec, + /// How many fields with this configuration should be created + pub count: Option, +} + +impl From for FieldSpec { + fn from(value: FieldSpecIntermediate) -> Self { + let field_value_spec = if let Some(b) = value.bool { + FieldValueSpec::Bool(b) + } else if let Some((start, end)) = value.i64_range { + FieldValueSpec::I64 { + range: (start..end), + increment: value.increment.unwrap_or(false), + reset_after: value.reset_after, + } + } else if let Some((start, end)) = value.f64_range { + FieldValueSpec::F64 { + range: (start..end), + } + } else if let Some(pattern) = value.template { + FieldValueSpec::String { + pattern, + replacements: value.replacements, + } + } else if let Some(kind) = value.uptime { + FieldValueSpec::Uptime { kind } + } else { + panic!( + "Can't tell what type of field value you're trying to specify with this \ + configuration: `{value:?}" + ); + }; + + Self { + name: value.name, + field_value_spec, + count: value.count, + } + } +} + +/// The specification of a field value of a particular type. Instances should be +/// created by converting a `FieldSpecIntermediate`, which more closely matches +/// the TOML structure. +#[derive(Debug, PartialEq)] +pub enum FieldValueSpec { + /// Configuration of a boolean field. + Bool(bool), + /// Configuration of an integer field. + I64 { + /// The `Range` in which random integer values will be generated. If the + /// range only contains one value, all instances of this field + /// will have the same value. 
+ range: Range, + /// When set to true, after an initial random value in the range is + /// generated, a random increment in the range will be generated + /// and added to the initial value. That means the + /// value for this field will always be increasing. When the value + /// reaches the max value of i64, the value will wrap around to + /// the min value of i64 and increment again. + increment: bool, + /// If `increment` is true, after this many samples, reset the value to + /// start the increasing value over. If this is `None`, the + /// value won't restart until reaching the max value of i64. If + /// `increment` is false, this has no effect. + reset_after: Option, + }, + /// Configuration of a floating point field. + F64 { + /// The `Range` in which random floating point values will be generated. + /// If start == end, all instances of this field will have the + /// same value. + range: Range, + }, + /// Configuration of a string field. + String { + /// Pattern containing placeholders that specifies how to generate the + /// string values. + /// + /// Valid placeholders include: + /// + /// - `{{agent_name}}` - the agent spec's name, with any replacements + /// done + /// - `{{time}}` - the current time in nanoseconds since the epoch. + /// TODO: support specifying a strftime + /// - any other placeholders as specified in `replacements`. If a + /// placeholder has no value specified in `replacements`, it will end + /// up as-is in the field value. + pattern: String, + /// A list of replacement placeholders and the values to replace them + /// with. The values can optionally have weights associated with + /// them to change the probabilities that its value + /// will be used. + replacements: Vec, + }, + /// Configuration of a field with the value of the number of seconds the + /// data generation tool has been running. 
+ Uptime { + /// Format of the uptime value in this field + kind: UptimeKind, + }, +} + +/// The kind of field value to create using the data generation tool's uptime +#[derive(Debug, PartialEq, Eq, Copy, Clone, Deserialize)] +pub enum UptimeKind { + /// Number of seconds since the tool started running as an i64 field + #[serde(rename = "i64")] + I64, + /// Number of seconds since the tool started running, formatted as a string + /// field containing the value in the format "x day(s), HH:MM" + #[serde(rename = "telegraf")] + Telegraf, +} + +#[cfg(test)] +impl Default for FieldValueSpec { + fn default() -> Self { + Self::Bool(true) + } +} + +/// An intermediate representation of the field specification that more directly +/// corresponds to the way field configurations are expressed in TOML. This +/// structure is transformed into the `FieldValueSpec` enum that ensures the +/// options for the different field value types are mutually exclusive. +#[derive(Deserialize, Debug)] +#[serde(deny_unknown_fields)] +struct FieldSpecIntermediate { + /// Key/name for this field. Can be a plain string or a string with + /// placeholders for: + /// + /// - `{{agent_id}}` - the agent ID + /// - `{{measurement_id}}` - the measurement ID + /// - `{{field_id}}` - the field ID, which must be used if `count` > 1 so + /// that unique field names are created + name: String, + /// The number of fields with this configuration that should be created. + /// Default value is 1. If specified, use `{{field_id}}` in this field's + /// `name` to create unique fields. + count: Option, + /// Specify `bool` to make a field that has the Boolean type. `true` means + /// to generate the boolean randomly with equal probability. `false` + /// means...? Specifying any other optional fields along with this one + /// is invalid. + bool: Option, + /// Specify `i64_range` to make an integer field. The values will be + /// randomly generated within the specified range with equal + /// probability. 
If the range only contains one element, all occurrences + /// of this field will have the same value. Can be combined with + /// `increment`; specifying any other optional fields is invalid. + i64_range: Option<(i64, i64)>, + /// Specify `f64_range` to make a floating point field. The values will be + /// randomly generated within the specified range. If start == end, all + /// occurrences of this field will have that value. + /// Can this be combined with `increment`? + f64_range: Option<(f64, f64)>, + /// When set to true with an `i64_range` (is this valid with any other + /// type?), after an initial random value is generated, a random + /// increment will be generated and added to the initial value. That + /// means the value for this field will always be increasing. When the value + /// reaches the end of the range...? The end of the range will be repeated + /// forever? The series will restart at the start of the range? + /// Something else? Setting this to `Some(false)` has the same effect as + /// `None`. + increment: Option, + /// If `increment` is true, after this many samples, reset the value to + /// start the increasing value over. If this is `None`, the value won't + /// restart until reaching the max value of i64. If `increment` is + /// false, this has no effect. + reset_after: Option, + /// Set `pattern` to make a field with the string type. If this doesn't + /// include any placeholders, all occurrences of this field will have + /// this value. + /// + /// Valid placeholders include: + /// + /// - `{{agent.id}}` - the agent spec's name, with any replacements done + /// - any other placeholders as specified in `replacements`. If a + /// placeholder has no value specified in `replacements`, it will end up + /// as-is in the field value. + template: Option, + /// A list of replacement placeholders and the values to replace them with. + /// If a placeholder specified here is not used in `pattern`, it will + /// have no effect. 
The values may optionally have a probability weight + /// specified with them; if not specified, the value will have weight 1. + /// If no weights are specified, the values will be generated with equal + /// probability. + #[serde(default)] + replacements: Vec, + /// The kind of uptime that should be used for this field. If specified, no + /// other options are valid. If not specified, this is not an uptime + /// field. + uptime: Option, +} + +/// The specification of what values to substitute in for placeholders specified +/// in `String` field values. +#[derive(Deserialize, Debug, PartialEq, Eq, Clone)] +#[serde(deny_unknown_fields)] +pub struct Replacement { + /// A placeholder key that can be used in field `pattern`s. + pub replace: String, + /// The possible values to use instead of the placeholder key in `pattern`. + /// Values may optionally have a weight specified. If no weights are + /// specified, the values will be randomly generated with equal + /// probability. The weights are passed to [`rand`'s `choose_weighted` + /// method][choose_weighted] and are a relative likelihood such that the + /// probability of each item being selected is its weight divided by the sum + /// of all weights in this group. + /// + /// [choose_weighted]: https://docs.rs/rand/0.7.3/rand/seq/trait.SliceRandom.html#tymethod.choose_weighted + pub with: Vec, +} + +#[derive(Debug, Deserialize, PartialEq, Eq, Clone)] +#[serde(untagged, deny_unknown_fields)] +/// A possible value to use instead of a placeholder key, optionally with an +/// associated weight. If no weight is specified, the weight used will be 1. +pub enum ReplacementValue { + /// Just a value without a weight + String(String), + /// A value with a specified relative likelihood weight that gets passed on + /// to [`rand`'s `choose_weighted` method][choose_weighted]. The + /// probability of each item being selected is its weight divided by the + /// sum of all weights in the `Replacement` group. 
+ /// + /// [choose_weighted]: https://docs.rs/rand/0.7.3/rand/seq/trait.SliceRandom.html#tymethod.choose_weighted + Weighted(String, u32), +} + +impl ReplacementValue { + /// The associated replacement value + pub fn value(&self) -> &str { + use ReplacementValue::*; + match self { + String(s) => s, + Weighted(s, ..) => s, + } + } + + /// The associated weight value specified; defaults to 1. + pub fn weight(&self) -> u32 { + use ReplacementValue::*; + match self { + String(..) => 1, + Weighted(.., w) => *w, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn sample_schemas_parse() { + let schemas: Vec<&str> = vec![ + include_str!("../schemas/storage_cardinality_example.toml"), + include_str!("../schemas/cap-write.toml"), + include_str!("../schemas/tracing-spec.toml"), + include_str!("../schemas/full_example.toml"), + ]; + + for s in schemas { + if let Err(e) = DataSpec::from_str(s) { + panic!("error {e:?} on\n{s}") + } + } + } + + #[test] + fn not_specifying_vectors_gets_default_empty_vector() { + let toml = r#" +name = "demo_schema" + +[[agents]] +name = "foo" + +[[agents.measurements]] +name = "cpu" + +[[agents.measurements.fields]] +name = "host" +template = "server" + +[[database_writers]] +database_ratio = 1.0 +agents = [{name = "foo", sampling_interval = "10s"}] +"#; + let spec = DataSpec::from_str(toml).unwrap(); + + let agent0 = &spec.agents[0]; + assert!(agent0.tag_pairs.is_empty()); + + let agent0_measurements = &agent0.measurements; + let a0m0 = &agent0_measurements[0]; + assert!(a0m0.tag_pairs.is_empty()); + + let a0m0_fields = &a0m0.fields; + let a0m0f0 = &a0m0_fields[0]; + let field_spec = &a0m0f0.field_value_spec; + + assert!( + matches!( + field_spec, + FieldValueSpec::String { replacements, .. 
} if replacements.is_empty() + ), + "expected a String field with empty replacements; was {field_spec:?}" + ); + } + + #[test] + fn split_databases_by_writer_spec_ratio() { + let toml = r#" +name = "demo_schema" + +[[agents]] +name = "foo" +[[agents.measurements]] +name = "cpu" +[[agents.measurements.fields]] +name = "host" +template = "server" + +[[agents]] +name = "bar" +[[agents.measurements]] +name = "whatevs" +[[agents.measurements.fields]] +name = "val" +i64_range = [0, 10] + +[[database_writers]] +database_ratio = 0.6 +agents = [{name = "foo", sampling_interval = "10s"}] + +[[database_writers]] +database_ratio = 0.4 +agents = [{name = "bar", sampling_interval = "1m", count = 3}] +"#; + let spec = DataSpec::from_str(toml).unwrap(); + let databases = vec!["a_1".to_string(), "a_2".to_string(), "b_1".to_string()]; + + let database_agents = spec.database_split_to_agents(&databases).unwrap(); + + let b = &database_agents[0]; + assert_eq!(b.database, &databases[0]); + assert_eq!( + b.agent_assignments[0].sampling_interval, + Duration::from_secs(10) + ); + assert_eq!(b.agent_assignments[0].count, 1); + assert_eq!(b.agent_assignments[0].spec.name, "foo"); + + let b = &database_agents[1]; + assert_eq!(b.database, &databases[1]); + assert_eq!( + b.agent_assignments[0].sampling_interval, + Duration::from_secs(10) + ); + assert_eq!(b.agent_assignments[0].count, 1); + assert_eq!(b.agent_assignments[0].spec.name, "foo"); + + let b = &database_agents[2]; + assert_eq!(b.database, &databases[2]); + assert_eq!( + b.agent_assignments[0].sampling_interval, + Duration::from_secs(60) + ); + assert_eq!(b.agent_assignments[0].count, 3); + assert_eq!(b.agent_assignments[0].spec.name, "bar"); + } + + #[test] + fn split_databases_by_writer_spec_regex() { + let toml = r#" +name = "demo_schema" + +[[agents]] +name = "foo" +[[agents.measurements]] +name = "cpu" +[[agents.measurements.fields]] +name = "host" +template = "server" + +[[agents]] +name = "bar" +[[agents.measurements]] +name = 
"whatevs" +[[agents.measurements.fields]] +name = "val" +i64_range = [0, 10] + +[[database_writers]] +database_regex = "foo.*" +agents = [{name = "foo", sampling_interval = "10s"}] + +[[database_writers]] +database_regex = ".*_bar" +agents = [{name = "bar", sampling_interval = "1m", count = 3}] +"#; + + let spec = DataSpec::from_str(toml).unwrap(); + let databases = vec![ + "foo_1".to_string(), + "foo_2".to_string(), + "asdf_bar".to_string(), + ]; + + let database_agents = spec.database_split_to_agents(&databases).unwrap(); + + let b = &database_agents[0]; + assert_eq!(b.database, &databases[0]); + assert_eq!( + b.agent_assignments[0].sampling_interval, + Duration::from_secs(10) + ); + assert_eq!(b.agent_assignments[0].count, 1); + assert_eq!(b.agent_assignments[0].spec.name, "foo"); + + let b = &database_agents[1]; + assert_eq!(b.database, &databases[1]); + assert_eq!( + b.agent_assignments[0].sampling_interval, + Duration::from_secs(10) + ); + assert_eq!(b.agent_assignments[0].count, 1); + assert_eq!(b.agent_assignments[0].spec.name, "foo"); + + let b = &database_agents[2]; + assert_eq!(b.database, &databases[2]); + assert_eq!( + b.agent_assignments[0].sampling_interval, + Duration::from_secs(60) + ); + assert_eq!(b.agent_assignments[0].count, 3); + assert_eq!(b.agent_assignments[0].spec.name, "bar"); + } + + #[test] + fn split_databases_by_writer_regex_and_ratio_error() { + let toml = r#" +name = "demo_schema" + +[[agents]] +name = "foo" +[[agents.measurements]] +name = "cpu" +[[agents.measurements.fields]] +name = "host" +template = "server" + +[[agents]] +name = "bar" +[[agents.measurements]] +name = "whatevs" +[[agents.measurements.fields]] +name = "val" +i64_range = [0, 10] + +[[database_writers]] +database_ratio = 0.8 +agents = [{name = "foo", sampling_interval = "10s"}] + +[[database_writers]] +database_regex = "foo.*" +agents = [{name = "bar", sampling_interval = "1m", count = 3}] +"#; + + let spec = DataSpec::from_str(toml).unwrap(); + let databases = 
vec!["a_1".to_string(), "a_2".to_string(), "b_1".to_string()]; + + let database_agents = spec.database_split_to_agents(&databases); + assert!(matches!( + database_agents.unwrap_err(), + Error::DatabaseWritersConfig + )); + } + + #[test] + fn split_databases_by_writer_ratio_defaults() { + let toml = r#" +name = "demo_schema" + +[[agents]] +name = "foo" +[[agents.measurements]] +name = "cpu" +[[agents.measurements.fields]] +name = "host" +template = "server" + +[[database_writers]] +agents = [{name = "foo", sampling_interval = "10s"}] +"#; + + let spec = DataSpec::from_str(toml).unwrap(); + let databases = vec!["a_1".to_string(), "a_2".to_string()]; + + let database_agents = spec.database_split_to_agents(&databases).unwrap(); + + let b = &database_agents[0]; + assert_eq!(b.database, &databases[0]); + assert_eq!( + b.agent_assignments[0].sampling_interval, + Duration::from_secs(10) + ); + assert_eq!(b.agent_assignments[0].count, 1); + assert_eq!(b.agent_assignments[0].spec.name, "foo"); + + let b = &database_agents[1]; + assert_eq!(b.database, &databases[1]); + assert_eq!( + b.agent_assignments[0].sampling_interval, + Duration::from_secs(10) + ); + assert_eq!(b.agent_assignments[0].count, 1); + assert_eq!(b.agent_assignments[0].spec.name, "foo"); + } +} diff --git a/iox_data_generator/src/substitution.rs b/iox_data_generator/src/substitution.rs new file mode 100644 index 0000000..b5e558a --- /dev/null +++ b/iox_data_generator/src/substitution.rs @@ -0,0 +1,241 @@ +//! Substituting dynamic values into a template as specified in various places +//! in the schema. 
+ +use crate::specification; +use chrono::prelude::*; +use handlebars::{ + Context, Handlebars, Helper, HelperDef, HelperResult, Output, RenderContext, RenderErrorReason, +}; +use rand::rngs::SmallRng; +use rand::{distributions::Alphanumeric, seq::SliceRandom, Rng, RngCore}; +use serde_json::Value; +use snafu::{ResultExt, Snafu}; +use std::collections::BTreeMap; + +/// Substitution-specific Results +pub type Result = std::result::Result; + +/// Errors that may happen while substituting values into templates. +#[derive(Snafu, Debug)] +#[allow(missing_docs)] +pub enum Error { + #[snafu(display( + "Could not perform text substitution in `{}`, caused by:\n{}", + template, + source + ))] + CantCompileTemplate { + #[snafu(source(from(handlebars::TemplateError, Box::new)))] + source: Box, + template: String, + }, + + #[snafu(display("Could not render template {}, caused by: {}", name, source))] + CantRenderTemplate { + name: String, + #[snafu(source(from(handlebars::RenderError, Box::new)))] + source: Box, + }, + + #[snafu(display( + "Could not perform text substitution in `{}`, caused by:\n{}", + template, + source + ))] + CantPerformSubstitution { + #[snafu(source(from(handlebars::RenderError, Box::new)))] + source: Box, + template: String, + }, +} + +pub(crate) fn render_once(name: &str, template: impl Into, data: &Value) -> Result { + let mut registry = new_handlebars_registry(); + registry.set_strict_mode(true); + let template = template.into(); + registry + .register_template_string(name, &template) + .context(CantCompileTemplateSnafu { template })?; + registry + .render(name, data) + .context(CantRenderTemplateSnafu { name }) +} + +pub(crate) fn new_handlebars_registry() -> Handlebars<'static> { + let mut registry = Handlebars::new(); + registry.set_strict_mode(true); + registry.register_helper("format-time", Box::new(FormatNowHelper)); + registry.register_helper("random", Box::new(RandomHelper)); + registry.register_helper("guid", Box::new(GuidHelper)); + registry 
+} + +#[derive(Debug)] +pub(crate) struct RandomHelper; + +impl HelperDef for RandomHelper { + fn call<'reg: 'rc, 'rc>( + &self, + h: &Helper<'_>, + _: &Handlebars<'_>, + _: &Context, + _: &mut RenderContext<'_, '_>, + out: &mut dyn Output, + ) -> HelperResult { + let param = h + .param(0) + .ok_or(RenderErrorReason::ParamNotFoundForIndex("random", 0))? + .value() + .as_u64() + .ok_or_else(|| { + RenderErrorReason::ParamTypeMismatchForName( + "random", + "0".to_string(), + "unsigned integer".to_string(), + ) + })? + .try_into() + .map_err(|_| { + RenderErrorReason::Other("`random`'s parameter must fit in a usize".to_string()) + })?; + + let mut rng = rand::thread_rng(); + + let random: String = std::iter::repeat(()) + .map(|()| rng.sample(Alphanumeric)) + .map(char::from) + .take(param) + .collect(); + + out.write(&random)?; + + Ok(()) + } +} + +#[derive(Debug)] +pub(crate) struct FormatNowHelper; + +impl HelperDef for FormatNowHelper { + fn call<'reg: 'rc, 'rc>( + &self, + h: &Helper<'_>, + _: &Handlebars<'_>, + c: &Context, + _: &mut RenderContext<'_, '_>, + out: &mut dyn Output, + ) -> HelperResult { + let format = h + .param(0) + .ok_or(RenderErrorReason::ParamNotFoundForIndex("format-time", 0))? 
+ .render(); + + let timestamp = c + .data() + .get("timestamp") + .and_then(|t| t.as_i64()) + .expect("Caller of `render` should have set `timestamp` to an `i64` value"); + + let datetime = Utc.timestamp_nanos(timestamp); + + out.write(&datetime.format(&format).to_string())?; + + Ok(()) + } +} + +#[derive(Debug)] +pub(crate) struct GuidHelper; + +impl HelperDef for GuidHelper { + fn call<'reg: 'rc, 'rc>( + &self, + _h: &Helper<'_>, + _: &Handlebars<'_>, + _: &Context, + _: &mut RenderContext<'_, '_>, + out: &mut dyn Output, + ) -> HelperResult { + let mut rng = rand::thread_rng(); + + let mut bytes = [0u8; 16]; + rng.fill_bytes(&mut bytes); + let mut uid_builder = uuid::Builder::from_bytes(bytes); + uid_builder.set_variant(uuid::Variant::RFC4122); + uid_builder.set_version(uuid::Version::Random); + let uid = uid_builder.into_uuid().to_string(); + + out.write(&uid)?; + + Ok(()) + } +} + +/// Given a random number generator and replacement specification, choose a +/// particular value from the list of possible values according to any specified +/// weights (or with equal probability if there are no weights). 
+pub fn pick_from_replacements<'a>( + rng: &mut SmallRng, + replacements: &'a [specification::Replacement], +) -> BTreeMap<&'a str, &'a str> { + replacements + .iter() + .map(|replacement| { + let chosen = replacement + .with + .choose_weighted(rng, |value| value.weight()) + .expect("`Replacement` `with` should have items") + .value(); + + (replacement.replace.as_str(), chosen) + }) + .collect() +} + +#[cfg(test)] +mod test { + use super::*; + use serde_json::json; + + #[test] + fn format_now_valid_strftime() { + let mut registry = new_handlebars_registry(); + registry + .register_template_string("t", r#"the date is {{format-time "%Y-%m-%d"}}."#) + .unwrap(); + + let timestamp: i64 = 1599154445000000000; + let value = registry + .render("t", &json!({ "timestamp": timestamp })) + .unwrap(); + + assert_eq!(&value, "the date is 2020-09-03."); + } + + #[test] + #[should_panic(expected = "a Display implementation returned an error unexpectedly: Error")] + fn format_now_invalid_strftime_panics() { + let mut registry = new_handlebars_registry(); + registry + .register_template_string("t", r#"the date is {{format-time "%-B"}}."#) + .unwrap(); + + let timestamp: i64 = 1599154445000000000; + let _value = registry + .render("t", &json!({ "timestamp": timestamp })) + .expect("this is unreachable"); + } + + #[test] + fn format_now_missing_strftime() { + let mut registry = new_handlebars_registry(); + registry + .register_template_string("t", r#"the date is {{format-time}}."#) + .unwrap(); + + let timestamp: i64 = 1599154445000000000; + let result = registry.render("t", &json!({ "timestamp": timestamp })); + + assert!(result.is_err()); + } +} diff --git a/iox_data_generator/src/tag_pair.rs b/iox_data_generator/src/tag_pair.rs new file mode 100644 index 0000000..302fc4d --- /dev/null +++ b/iox_data_generator/src/tag_pair.rs @@ -0,0 +1,173 @@ +//! 
Module for generating tag key/value pairs to be used in the data generator + +use crate::specification::TagPairSpec; +use crate::substitution::new_handlebars_registry; +use handlebars::Handlebars; +use serde_json::{json, Value}; +use snafu::{ResultExt, Snafu}; +use std::fmt::Formatter; +use std::sync::{Arc, Mutex}; + +/// Results specific to the tag_pair module +pub(crate) type Result = std::result::Result; + +/// Errors that may happen while creating or regenerating tag pairs +#[derive(Snafu, Debug)] +pub enum Error { + #[snafu(display( + "Could not compile template for tag pair {} caused by: {}", + tag_key, + source + ))] + CantCompileTemplate { + tag_key: String, + #[snafu(source(from(handlebars::TemplateError, Box::new)))] + source: Box, + }, + + #[snafu(display( + "Could not render template for tag pair {}, cause by: {}", + tag_key, + source + ))] + CantRenderTemplate { + tag_key: String, + #[snafu(source(from(handlebars::RenderError, Box::new)))] + source: Box, + }, +} + +#[derive(Debug)] +pub enum TagPair { + Static(StaticTagPair), + Regenerating(Box>), +} + +impl TagPair { + pub fn pairs_from_specs( + specs: &[TagPairSpec], + mut template_data: Value, + ) -> Result>> { + let tag_pairs: Vec<_> = specs + .iter() + .map(|tag_pair_spec| { + let tag_count = tag_pair_spec.count.unwrap_or(1); + + let tags: Vec<_> = (1..tag_count + 1) + .map(|tag_id| { + let tag_key = if tag_id == 1 { + tag_pair_spec.key.to_string() + } else { + format!("{}{}", tag_pair_spec.key, tag_id) + }; + + let data = template_data.as_object_mut().expect("data must be object"); + data.insert("id".to_string(), json!(tag_id)); + data.insert("line_number".to_string(), json!(1)); + + let mut template = new_handlebars_registry(); + template + .register_template_string(&tag_key, &tag_pair_spec.template) + .context(CantCompileTemplateSnafu { + tag_key: &tag_pair_spec.key, + })?; + + let value = template + .render(&tag_key, &template_data) + .context(CantRenderTemplateSnafu { tag_key: &tag_key })?; + 
+ let tag_pair = StaticTagPair { + key: Arc::new(tag_key), + value: Arc::new(value), + }; + + let tag_pair = if let Some(regenerate_after_lines) = + tag_pair_spec.regenerate_after_lines + { + let regenerating_pair = RegeneratingTagPair { + regenerate_after_lines, + tag_pair, + template, + line_number: 0, + data: template_data.clone(), + }; + + Self::Regenerating(Box::new(Mutex::new(regenerating_pair))) + } else { + Self::Static(tag_pair) + }; + + Ok(Arc::new(tag_pair)) + }) + .collect::>>()?; + + Ok(tags) + }) + .collect::>>()?; + + Ok(tag_pairs.into_iter().flatten().collect()) + } + + pub fn key(&self) -> String { + match self { + Self::Static(p) => p.key.to_string(), + Self::Regenerating(p) => { + let p = p.lock().expect("mutex poisoned"); + p.tag_pair.key.to_string() + } + } + } +} + +/// A tag key/value pair +#[derive(Debug, PartialEq, Eq, PartialOrd, Clone)] +pub struct StaticTagPair { + /// the key + pub key: Arc, + /// the value + pub value: Arc, +} + +impl std::fmt::Display for StaticTagPair { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}={}", self.key, self.value) + } +} + +/// Used for tag pairs specified in either an agent or measurement specification. The +/// spec must be kept around to support regenerating the tag pair. 
+#[derive(Debug, Clone)] +pub struct RegeneratingTagPair { + regenerate_after_lines: usize, + tag_pair: StaticTagPair, + template: Handlebars<'static>, + data: Value, + line_number: usize, +} + +impl RegeneratingTagPair { + pub fn tag_pair(&mut self) -> &StaticTagPair { + self.line_number += 1; + + if self.should_regenerate() { + let data = self.data.as_object_mut().expect("data must be object"); + data.insert("line_number".to_string(), json!(self.line_number)); + + let value = self + .template + .render(self.tag_pair.key.as_str(), &self.data) + .expect("this template has been rendered before so this shouldn't be possible"); + + self.tag_pair = StaticTagPair { + key: Arc::clone(&self.tag_pair.key), + value: Arc::new(value), + }; + } + + &self.tag_pair + } + + fn should_regenerate(&self) -> bool { + self.line_number % (self.regenerate_after_lines + 1) == 0 + } +} diff --git a/iox_data_generator/src/tag_set.rs b/iox_data_generator/src/tag_set.rs new file mode 100644 index 0000000..a92f4a4 --- /dev/null +++ b/iox_data_generator/src/tag_set.rs @@ -0,0 +1,624 @@ +//! Code for defining values and tag sets with tags that are dependent on other tags. + +use crate::now_ns; +use crate::specification::{DataSpec, ValuesSpec}; +use crate::substitution::new_handlebars_registry; +use crate::tag_pair::StaticTagPair; +use handlebars::Handlebars; +use itertools::Itertools; +use serde_json::json; +use snafu::{OptionExt, ResultExt, Snafu}; +/// Module for pre-generated values and tag sets that can be used when generating samples from +/// agents. +use std::collections::BTreeMap; +use std::fmt::Formatter; +use std::sync::Arc; + +/// Errors that may happen while reading a TOML specification. 
+#[derive(Snafu, Debug)] +#[allow(missing_docs)] +pub enum Error { + #[snafu(display("{} specifies a has_one member {} that isn't defined", value, has_one))] + HasOneDependencyNotDefined { value: String, has_one: String }, + + /// Error that may happen when compiling a template from the values specification + #[snafu(display("Could not compile template `{}`, caused by:\n{}", template, source))] + CantCompileTemplate { + /// Underlying Handlebars error that caused this problem + #[snafu(source(from(handlebars::TemplateError, Box::new)))] + source: Box, + /// Template that caused this problem + template: String, + }, + + /// Error that may happen when rendering a template with passed in data + #[snafu(display("Could not render template `{}`, caused by:\n{}", template, source))] + CantRenderTemplate { + /// Underlying Handlebars error that caused this problem + #[snafu(source(from(handlebars::RenderError, Box::new)))] + source: Box, + /// Template that caused this problem + template: String, + }, + + #[snafu(display( + "has_one {} must be accessed through its parent (e.g. parent foo with has_one bar: foo.bar", + has_one + ))] + HasOneWithoutParent { has_one: String }, + + #[snafu(display("no has_one found values for {}", has_one))] + HasOneNotFound { has_one: String }, + + #[snafu(display("has_one {} not found for parent id {}", has_one, parent_id))] + HasOneNotFoundForParent { has_one: String, parent_id: usize }, +} + +type Result = std::result::Result; + +/// A single generated value's id and tag key/value pair. +#[derive(Debug)] +pub struct GeneratedValue { + id: usize, + tag_pair: Arc, +} + +/// All generated tag sets specified +#[derive(Debug, Default)] +pub struct GeneratedTagSets { + // These map the name of a collection of values to its values. All values will have + // an entry in this map. For has_one and child_values, they will have duplicates there + // as well to make generating the tag sets possible. 
+ values: BTreeMap>>, + // each parent-child will have its children stored in this map. The children map + // the id of the parent to the collection of its children values + child_values: BTreeMap>>>, + // each parent-has_one will have its has_ones stored in this map + has_one_values: BTreeMap, + // this maps the name of the tag set specified in the spec to the collection of tag + // sets that were pre-generated. + tag_sets: BTreeMap>>, +} + +/// Generated parent to has_one mappings +#[derive(Debug, Default)] +pub struct ParentToHasOnes { + // each parent id will have its has_ones stored in this map. The map within + // maps the has_one name to its generated value + id_to_has_ones: BTreeMap, Arc>>, +} + +impl GeneratedTagSets { + /// Generate tag sets from a `DataSpec` + pub fn from_spec(spec: &DataSpec) -> Result { + let mut generated_tag_sets = Self::default(); + let mut template = new_handlebars_registry(); + + let mut leftover_specs = -1; + + loop { + if leftover_specs == 0 { + break; + } + + let new_leftover = generated_tag_sets.generate_values(&mut template, spec)? 
as i64;
            if new_leftover == leftover_specs {
                panic!("unresolvable loop in values generation");
            }
            leftover_specs = new_leftover;
        }

        generated_tag_sets.generate_tag_sets(spec)?;

        Ok(generated_tag_sets)
    }

    /// Returns the tag sets for the given name
    pub fn sets_for(&self, name: &str) -> Option<&Arc<Vec<TagSet>>> {
        self.tag_sets.get(name)
    }

    /// Generate all value collections whose dependencies are already satisfied.
    ///
    /// Returns the number of specs that could NOT yet be generated (because a
    /// `has_one`/`belongs_to` dependency is still missing); the caller loops
    /// until this count stops changing.
    fn generate_values(
        &mut self,
        registry: &mut Handlebars<'static>,
        data_spec: &DataSpec,
    ) -> Result<usize> {
        let mut leftover_count = 0;

        for spec in &data_spec.values {
            if self.values.contains_key(&spec.name) {
                continue;
            } else if !self.can_generate(spec) {
                leftover_count += 1;
                continue;
            }

            self.generate_values_spec(registry, spec)?;
        }

        Ok(leftover_count)
    }

    /// Generate every named tag set declared in the spec.
    fn generate_tag_sets(&mut self, data_spec: &DataSpec) -> Result<()> {
        for set_spec in &data_spec.tag_sets {
            self.generate_tag_set_spec(&set_spec.name, &set_spec.for_each)?;
        }

        Ok(())
    }

    /// Generate the cartesian tag set for a single `[[tag_sets]]` entry.
    fn generate_tag_set_spec(&mut self, set_name: &str, for_each: &[String]) -> Result<()> {
        let mut tag_set_keys: Vec<_> = for_each
            .iter()
            .map(|v| Key {
                // the tag key is the last dotted segment ("foo.bar" -> "bar")
                name: v.split('.').last().unwrap(),
                value: v.to_string(),
                position: 0,
            })
            .collect();

        // this weird bit is so that we don't need to sort the tag pairs as we're generating. All
        // tag sets here will have the exact same tags and sort order, so do it once and inject tags
        // in the appropriate place
        let mut sorted_keys: Vec<_> = tag_set_keys.iter_mut().collect();
        // &str has a total order, so use Ord::cmp rather than partial_cmp().unwrap()
        sorted_keys.sort_unstable_by(|a, b| a.name.cmp(b.name));
        for (pos, k) in sorted_keys.iter_mut().enumerate() {
            k.position = pos;
        }

        // we pass in a pre-built tag_pairs vec so that we can fill it out as we walk down the
        // for_each iteration and then just do a single clone at the very end.
        let mut tag_pairs: Vec<_> = (0..for_each.len())
            .map(|_| {
                Arc::new(StaticTagPair {
                    key: Arc::new("default".to_string()),
                    value: Arc::new("default".to_string()),
                })
            })
            .collect();
        let tag_sets = self.for_each_tag_set(None, &tag_set_keys, &mut tag_pairs, 0)?;
        self.tag_sets
            .insert(set_name.to_string(), Arc::new(tag_sets));

        Ok(())
    }

    /// Recursively walk the `for_each` keys, filling `tag_pairs` in place and
    /// emitting one `TagSet` per leaf combination.
    fn for_each_tag_set(
        &self,
        parent_id: Option<usize>,
        keys: &[Key<'_>],
        tag_pairs: &mut Vec<Arc<StaticTagPair>>,
        position: usize,
    ) -> Result<Vec<TagSet>> {
        let key = &keys[position];

        match self.get_generated_values(parent_id, &key.value) {
            Some(values) => {
                if position == keys.len() - 1 {
                    // leaf level: each value completes one tag set
                    let mut tag_sets = Vec::with_capacity(values.len());

                    for v in values {
                        tag_pairs[key.position] = Arc::clone(&v.tag_pair);
                        tag_sets.push(TagSet::new(tag_pairs.clone()));
                    }

                    return Ok(tag_sets);
                }

                let mut tag_sets = vec![];

                for v in values {
                    tag_pairs[key.position] = Arc::clone(&v.tag_pair);
                    let mut sets =
                        self.for_each_tag_set(Some(v.id), keys, tag_pairs, position + 1)?;
                    tag_sets.append(&mut sets);
                }

                Ok(tag_sets)
            }
            None => {
                // not a generated collection: must be a has_one relation keyed by parent id
                let parent_id = parent_id.expect("for_each_tag_set should never be called without a parent id if in has_one evaluation");
                let one = self
                    .has_one_values
                    .get(&key.value)
                    .context(HasOneNotFoundSnafu {
                        has_one: &key.value,
                    })?
                    .id_to_has_ones
                    .get(&parent_id)
                    .context(HasOneNotFoundForParentSnafu {
                        has_one: &key.value,
                        parent_id,
                    })?
                    .get(&key.value)
                    .expect("bug in generating values for has_one");
                let tag = Arc::clone(&one.tag_pair);
                tag_pairs[key.position] = tag;

                if position == keys.len() - 1 {
                    Ok(vec![TagSet::new(tag_pairs.clone())])
                } else {
                    self.for_each_tag_set(Some(parent_id), keys, tag_pairs, position + 1)
                }
            }
        }
    }

    /// Looks up generated values for `key`, scoped to `parent_id` when the
    /// collection is a `belongs_to` child; top-level collections ignore the parent.
    fn get_generated_values(
        &self,
        parent_id: Option<usize>,
        key: &str,
    ) -> Option<&Vec<Arc<GeneratedValue>>> {
        match self.child_values.get(key) {
            Some(child_values) => child_values.get(&parent_id.expect(
                "should never get_get_generated_values for child values without a parent_id",
            )),
            None => self.values.get(key),
        }
    }

    /// True when every `has_one` and `belongs_to` dependency of `spec` has
    /// already been generated.
    fn can_generate(&self, spec: &ValuesSpec) -> bool {
        let has_ones_ready = spec
            .has_one
            .iter()
            .flatten()
            .all(|name| self.values.contains_key(name));

        let belongs_to_ready = spec
            .belongs_to
            .as_ref()
            .map_or(true, |b| self.values.contains_key(b));

        has_ones_ready && belongs_to_ready
    }

    /// Render and store the values for one `[[values]]` spec.
    fn generate_values_spec(
        &mut self,
        template: &mut Handlebars<'static>,
        spec: &ValuesSpec,
    ) -> Result<()> {
        template
            .register_template_string(&spec.name, &spec.template)
            .context(CantCompileTemplateSnafu {
                template: &spec.name,
            })?;

        match &spec.belongs_to {
            Some(belongs_to) => self.generate_belongs_to(template, belongs_to.as_str(), spec)?,
            None => {
                let mut vals = Vec::with_capacity(spec.cardinality);
                let mut id_map = BTreeMap::new();
                let tag_key = Arc::new(spec.name.clone());

                // ids are 1-based
                for i in 1..=spec.cardinality {
                    id_map.insert("id", i);
                    id_map.insert("timestamp", now_ns() as usize);
                    let rendered_value =
                        template
                            .render(&spec.name, &id_map)
                            .context(CantRenderTemplateSnafu {
                                template: &spec.name,
                            })?;
                    let value = Arc::new(rendered_value);

                    vals.push(Arc::new(GeneratedValue {
                        id: i,
                        tag_pair: Arc::new(StaticTagPair {
                            key: Arc::clone(&tag_key),
                            value,
                        }),
                    }));
                }
                self.values.insert(spec.name.to_string(), vals);
            }
        }

        if let Some(has_ones) = spec.has_one.as_ref() {
            self.add_has_ones(&spec.name, has_ones)?;
        }

        Ok(())
    }

    /// Associate each parent value with one value from every `has_one`
    /// collection, cycling through the has_one values round-robin.
    fn add_has_ones(&mut self, parent: &str, has_ones: &[String]) -> Result<()> {
        let parent_values = self
            .values
            .get(parent)
            .expect("add_has_ones should never be called before the parent values are inserted");

        for has_one in has_ones {
            let parent_has_one_key = Arc::new(has_one_values_key(parent, has_one));
            let parent_has_ones = self
                .has_one_values
                .entry(parent_has_one_key.as_str().to_owned())
                .or_default();

            let has_one_values = self.values.get(has_one.as_str()).expect(
                "add_has_ones should never be called before the values collection is created",
            );

            let mut ones_iter = has_one_values.iter();
            for parent in parent_values {
                // wrap around when the has_one collection is exhausted
                let one_val = ones_iter.next().unwrap_or_else(|| {
                    ones_iter = has_one_values.iter();
                    ones_iter.next().unwrap()
                });

                let has_one_map = parent_has_ones.id_to_has_ones.entry(parent.id).or_default();
                has_one_map.insert(Arc::clone(&parent_has_one_key), Arc::clone(one_val));
            }
        }

        Ok(())
    }

    /// Render `spec.cardinality` child values for every parent value, keeping
    /// both the per-parent grouping and the flat "all children" collection.
    fn generate_belongs_to(
        &mut self,
        template: &mut Handlebars<'static>,
        belongs_to: &str,
        spec: &ValuesSpec,
    ) -> Result<()> {
        let parent_values = self.values.get(belongs_to).expect(
            "generate_belongs_to should never be called before the parent values are inserted",
        );
        let tag_key = Arc::new(spec.name.clone());

        let mut all_children = Vec::with_capacity(parent_values.len() * spec.cardinality);

        for parent in parent_values {
            let mut parent_owned = Vec::with_capacity(spec.cardinality);

            for _ in 0..spec.cardinality {
                // child ids are global across all parents, 1-based
                let child_value_id = all_children.len() + 1;
                let data = json!({
                    belongs_to: {
                        "id": parent.id,
                        "value": &parent.tag_pair.value.as_ref(),
                    },
                    "id": child_value_id,
                });

                let rendered_value =
                    template
                        .render(&spec.name, &data)
                        .context(CantRenderTemplateSnafu {
                            template: &spec.name,
                        })?;
                let value = Arc::new(rendered_value);

                let child_value = Arc::new(GeneratedValue {
                    id: child_value_id,
                    tag_pair: Arc::new(StaticTagPair {
                        key: Arc::clone(&tag_key),
                        value,
                    }),
                });

                parent_owned.push(Arc::clone(&child_value));
                all_children.push(child_value);
            }

            let child_vals = self
                .child_values
                .entry(child_values_key(belongs_to, &spec.name))
                .or_default();
            child_vals.insert(parent.id, parent_owned);
        }
        self.values.insert(spec.name.to_string(), all_children);

        Ok(())
    }
}

/// One `for_each` entry: the bare tag key name, the full dotted lookup value,
/// and its position in the (pre-sorted) output tag set.
struct Key<'a> {
    name: &'a str,
    value: String,
    position: usize,
}

fn child_values_key(parent: &str, child: &str) -> String {
    format!("{parent}.{child}")
}

fn has_one_values_key(parent: &str, child: &str) -> String {
    format!("{parent}.{child}")
}

/// A collection of tag key/value pairs
#[derive(Debug)]
pub struct TagSet {
    /// The tags in the set
    pub tags: Vec<Arc<StaticTagPair>>,
}

impl TagSet {
    fn new(tags: Vec<Arc<StaticTagPair>>) -> Self {
        Self { tags }
    }
}

impl std::fmt::Display for TagSet {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let s = self.tags.iter().map(|t| t.to_string()).join(",");
        write!(f, "{s}")
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use std::str::FromStr;

    #[test]
    fn generate_tag_sets_basic() {
        let toml = r#"
name = "demo"

[[values]]
name = "foo"
template = "{{id}}#foo"
cardinality = 3

[[tag_sets]]
name = "testage"
for_each = ["foo"]

[[agents]]
name = "foo"

[[agents.measurements]]
name = "cpu"

[[agents.measurements.fields]]
name = "f1"
i64_range = [0, 23]

[[database_writers]]
agents = [{name = "foo", sampling_interval = "10s"}]"#;

        let spec = DataSpec::from_str(toml).unwrap();
        let tag_sets = GeneratedTagSets::from_spec(&spec).unwrap();
        let testage = tag_sets.sets_for("testage").unwrap();
        let sets = testage.iter().map(|t| t.to_string()).join("\n");
+ let expected = r#" +foo=1#foo +foo=2#foo +foo=3#foo"#; + assert_eq!(expected[1..], sets); + } + + #[test] + fn generate_tag_sets_belongs_to() { + let toml = r#" +name = "demo" + +[[values]] +name = "foo" +template = "{{id}}#foo" +cardinality = 2 + +[[values]] +name = "bar" +template = "{{id}}-{{foo.id}}-{{foo.value}}" +cardinality = 2 +belongs_to = "foo" + +[[tag_sets]] +name = "testage" +for_each = [ + "foo", + "foo.bar", +] + +[[agents]] +name = "foo" + +[[agents.measurements]] +name = "cpu" + +[[agents.measurements.fields]] +name = "f1" +i64_range = [0, 23] + +[[database_writers]] +agents = [{name = "foo", sampling_interval = "10s"}]"#; + + let spec = DataSpec::from_str(toml).unwrap(); + let tag_sets = GeneratedTagSets::from_spec(&spec).unwrap(); + let testage = tag_sets.sets_for("testage").unwrap(); + let sets = testage.iter().map(|t| t.to_string()).join("\n"); + let expected = r#" +bar=1-1-1#foo,foo=1#foo +bar=2-1-1#foo,foo=1#foo +bar=3-2-2#foo,foo=2#foo +bar=4-2-2#foo,foo=2#foo"#; + assert_eq!(expected[1..], sets); + } + + #[test] + fn generate_tag_sets_test() { + let toml = r#" +name = "demo" + +[[values]] +name = "foo" +template = "{{id}}-foo" +cardinality = 3 +has_one = ["bar"] + +[[values]] +name = "bar" +template = "{{id}}-bar" +cardinality = 2 + +[[values]] +name = "asdf" +template = "{{id}}-asdf" +cardinality = 2 +belongs_to = "foo" +has_one = ["qwer"] + +[[values]] +name = "jkl" +template = "{{id}}-jkl" +cardinality = 2 + +[[values]] +name = "qwer" +template = "{{id}}-qwer" +cardinality = 6 + +[[tag_sets]] +name = "testage" +for_each = [ + "foo", + "foo.bar", + "foo.asdf", + "asdf.qwer", + "jkl" +] + +[[agents]] +name = "foo" + +[[agents.measurements]] +name = "cpu" + +[[agents.measurements.fields]] +name = "f1" +i64_range = [0, 23] + +[[database_writers]] +database_ratio = 1.0 +agents = [{name = "foo", sampling_interval = "10s"}]"#; + + let spec = DataSpec::from_str(toml).unwrap(); + let tag_sets = GeneratedTagSets::from_spec(&spec).unwrap(); + let 
testage = tag_sets.sets_for("testage").unwrap(); + let sets = testage.iter().map(|t| t.to_string()).join("\n"); + let expected = r#" +asdf=1-asdf,bar=1-bar,foo=1-foo,jkl=1-jkl,qwer=1-qwer +asdf=1-asdf,bar=1-bar,foo=1-foo,jkl=2-jkl,qwer=1-qwer +asdf=2-asdf,bar=1-bar,foo=1-foo,jkl=1-jkl,qwer=2-qwer +asdf=2-asdf,bar=1-bar,foo=1-foo,jkl=2-jkl,qwer=2-qwer +asdf=3-asdf,bar=2-bar,foo=2-foo,jkl=1-jkl,qwer=3-qwer +asdf=3-asdf,bar=2-bar,foo=2-foo,jkl=2-jkl,qwer=3-qwer +asdf=4-asdf,bar=2-bar,foo=2-foo,jkl=1-jkl,qwer=4-qwer +asdf=4-asdf,bar=2-bar,foo=2-foo,jkl=2-jkl,qwer=4-qwer +asdf=5-asdf,bar=1-bar,foo=3-foo,jkl=1-jkl,qwer=5-qwer +asdf=5-asdf,bar=1-bar,foo=3-foo,jkl=2-jkl,qwer=5-qwer +asdf=6-asdf,bar=1-bar,foo=3-foo,jkl=1-jkl,qwer=6-qwer +asdf=6-asdf,bar=1-bar,foo=3-foo,jkl=2-jkl,qwer=6-qwer"#; + assert_eq!(expected[1..], sets); + } +} diff --git a/iox_data_generator/src/write.rs b/iox_data_generator/src/write.rs new file mode 100644 index 0000000..1b0f701 --- /dev/null +++ b/iox_data_generator/src/write.rs @@ -0,0 +1,543 @@ +//! Writing generated points + +use crate::measurement::LineToGenerate; +use bytes::Bytes; +use datafusion_util::{unbounded_memory_pool, MemoryStream}; +use futures::stream; +use influxdb2_client::models::WriteDataPoint; +use mutable_batch_lp::lines_to_batches; +use parquet_file::{metadata::IoxMetadata, serialize}; +use schema::Projection; +use snafu::{ensure, ResultExt, Snafu}; +#[cfg(test)] +use std::{collections::BTreeMap, sync::Arc}; +use std::{ + fs::{self, File, OpenOptions}, + io::{BufWriter, Write}, + path::{Path, PathBuf}, + sync::Mutex, +}; + +/// Errors that may happen while writing points. 
+#[derive(Snafu, Debug)] +pub enum Error { + /// Error that may happen when writing line protocol to a file + #[snafu(display("Couldn't open line protocol file {}: {}", filename.display(), source))] + CantOpenLineProtocolFile { + /// The location of the file we tried to open + filename: PathBuf, + /// Underlying IO error that caused this problem + source: std::io::Error, + }, + + /// Error that may happen when writing Parquet to a file + #[snafu(display("Couldn't open Parquet file {}: {}", filename.display(), source))] + CantOpenParquetFile { + /// The location of the file we tried to open + filename: PathBuf, + /// Underlying IO error that caused this problem + source: std::io::Error, + }, + + /// Error that may happen when writing line protocol to a no-op sink + #[snafu(display("Could not generate line protocol: {}", source))] + CantWriteToNoOp { + /// Underlying IO error that caused this problem + source: std::io::Error, + }, + + /// Error that may happen when writing line protocol to a file + #[snafu(display("Could not write line protocol to file: {}", source))] + CantWriteToLineProtocolFile { + /// Underlying IO error that caused this problem + source: std::io::Error, + }, + + /// Error that may happen when writing line protocol to a Vec of bytes + #[snafu(display("Could not write to vec: {}", source))] + WriteToVec { + /// Underlying IO error that caused this problem + source: std::io::Error, + }, + + /// Error that may happen when writing Parquet to a file + #[snafu(display("Could not write Parquet: {}", source))] + WriteToParquetFile { + /// Underlying IO error that caused this problem + source: std::io::Error, + }, + + /// Error that may happen when converting line protocol to a mutable batch + #[snafu(display("Could not convert to a mutable batch: {}", source))] + ConvertToMutableBatch { + /// Underlying mutable_batch_lp error that caused this problem + source: mutable_batch_lp::Error, + }, + + /// Error that may happen when converting a mutable batch to 
an Arrow RecordBatch + #[snafu(display("Could not convert to a record batch: {}", source))] + ConvertToArrow { + /// Underlying mutable_batch error that caused this problem + source: mutable_batch::Error, + }, + + /// Error that may happen when creating a directory to store files to write + /// to + #[snafu(display("Could not create directory: {}", source))] + CantCreateDirectory { + /// Underlying IO error that caused this problem + source: std::io::Error, + }, + + /// Error that may happen when checking a path's metadata to see if it's a + /// directory + #[snafu(display("Could not get metadata: {}", source))] + CantGetMetadata { + /// Underlying IO error that caused this problem + source: std::io::Error, + }, + + /// Error that may happen if the path given to the file-based writer isn't a + /// directory + #[snafu(display("Expected to get a directory"))] + MustBeDirectory, + + /// Error that may happen while writing points to the API + #[snafu(display("Could not write points to API: {}", source))] + CantWriteToApi { + /// Underlying Influx client request error that caused this problem + source: influxdb2_client::RequestError, + }, + + /// Error that may happen while trying to create a bucket via the API + #[snafu(display("Could not create bucket: {}", source))] + CantCreateBucket { + /// Underlying Influx client request error that caused this problem + source: influxdb2_client::RequestError, + }, + + /// Error that may happen if attempting to create a bucket without + /// specifying the org ID + #[snafu(display("Could not create a bucket without an `org_id`"))] + OrgIdRequiredToCreateBucket, + + /// Error that may happen when serializing to Parquet + #[snafu(display("Could not serialize to Parquet"))] + ParquetSerialization { + /// Underlying `parquet_file` error that caused this problem + source: parquet_file::serialize::CodecError, + }, +} + +type Result = std::result::Result; + +/// Responsible for holding shared configuration needed to construct per-agent 
+/// points writers +#[derive(Debug)] +pub struct PointsWriterBuilder { + config: PointsWriterConfig, +} + +#[derive(Debug)] +enum PointsWriterConfig { + Api(influxdb2_client::Client), + Directory(PathBuf), + ParquetFile(PathBuf), + NoOp { + perform_write: bool, + }, + #[cfg(test)] + Vector(BTreeMap>>>), + Stdout, +} + +impl PointsWriterBuilder { + /// Write points to the API at the specified host and put them in the + /// specified org and bucket. + pub async fn new_api( + host: impl Into + Send, + token: impl Into + Send, + jaeger_debug: Option<&str>, + ) -> Result { + let host = host.into(); + + // Be somewhat lenient on what we accept as far as host; the client expects the + // protocol to be included. We could pull in the url crate and do more + // verification here. + let host = if host.starts_with("http") { + host + } else { + format!("http://{host}") + }; + + let mut client = influxdb2_client::Client::new(host, token.into()); + if let Some(header) = jaeger_debug { + client = client.with_jaeger_debug(header.to_string()); + } + + Ok(Self { + config: PointsWriterConfig::Api(client), + }) + } + + /// Write points to a file in the directory specified. + pub fn new_file>(path: P) -> Result { + fs::create_dir_all(&path).context(CantCreateDirectorySnafu)?; + let metadata = fs::metadata(&path).context(CantGetMetadataSnafu)?; + ensure!(metadata.is_dir(), MustBeDirectorySnafu); + + Ok(Self { + config: PointsWriterConfig::Directory(PathBuf::from(path.as_ref())), + }) + } + + /// Write points to a Parquet file in the directory specified. 
+ pub fn new_parquet>(path: P) -> Result { + fs::create_dir_all(&path).context(CantCreateDirectorySnafu)?; + let metadata = fs::metadata(&path).context(CantGetMetadataSnafu)?; + ensure!(metadata.is_dir(), MustBeDirectorySnafu); + + Ok(Self { + config: PointsWriterConfig::ParquetFile(PathBuf::from(path.as_ref())), + }) + } + + /// Write points to stdout + pub fn new_std_out() -> Self { + Self { + config: PointsWriterConfig::Stdout, + } + } + + /// Generate points but do not write them anywhere + pub fn new_no_op(perform_write: bool) -> Self { + Self { + config: PointsWriterConfig::NoOp { perform_write }, + } + } + + /// Create a writer out of this writer's configuration for a particular + /// agent that runs in a separate thread/task. + pub fn build_for_agent( + &mut self, + name: impl Into, + org: impl Into, + bucket: impl Into, + ) -> Result { + let inner_writer = match &mut self.config { + PointsWriterConfig::Api(client) => InnerPointsWriter::Api { + client: client.clone(), + org: org.into(), + bucket: bucket.into(), + }, + PointsWriterConfig::Directory(dir_path) => { + let mut filename = dir_path.clone(); + filename.push(name.into()); + filename.set_extension("txt"); + + let file = OpenOptions::new() + .append(true) + .create(true) + .open(&filename) + .context(CantOpenLineProtocolFileSnafu { filename })?; + + let file = Mutex::new(BufWriter::new(file)); + + InnerPointsWriter::File { file } + } + + PointsWriterConfig::ParquetFile(dir_path) => InnerPointsWriter::ParquetFile { + dir_path: dir_path.clone(), + agent_name: name.into(), + }, + + PointsWriterConfig::NoOp { perform_write } => InnerPointsWriter::NoOp { + perform_write: *perform_write, + }, + #[cfg(test)] + PointsWriterConfig::Vector(ref mut agents_by_name) => { + let v = agents_by_name + .entry(name.into()) + .or_insert_with(|| Arc::new(Mutex::new(Vec::new()))); + InnerPointsWriter::Vec(Arc::clone(v)) + } + PointsWriterConfig::Stdout => InnerPointsWriter::Stdout, + }; + + Ok(PointsWriter { inner_writer 
}) + } +} + +/// Responsible for writing points to the location it's been configured for. +#[derive(Debug)] +pub struct PointsWriter { + inner_writer: InnerPointsWriter, +} + +impl PointsWriter { + /// Write these points + pub async fn write_points( + &self, + points: impl Iterator + Send + Sync + 'static, + ) -> Result<()> { + self.inner_writer.write_points(points).await + } +} + +#[derive(Debug)] +enum InnerPointsWriter { + Api { + client: influxdb2_client::Client, + org: String, + bucket: String, + }, + File { + file: Mutex>, + }, + ParquetFile { + dir_path: PathBuf, + agent_name: String, + }, + NoOp { + perform_write: bool, + }, + #[cfg(test)] + Vec(Arc>>), + Stdout, +} + +impl InnerPointsWriter { + async fn write_points( + &self, + points: impl Iterator + Send + Sync + 'static, + ) -> Result<()> { + match self { + Self::Api { + client, + org, + bucket, + } => { + client + .write(org, bucket, stream::iter(points)) + .await + .context(CantWriteToApiSnafu)?; + } + Self::File { file } => { + for point in points { + let mut file = file.lock().expect("Should be able to get lock"); + point + .write_data_point_to(&mut *file) + .context(CantWriteToLineProtocolFileSnafu)?; + } + } + + Self::ParquetFile { + dir_path, + agent_name, + } => { + let mut raw_line_protocol = Vec::new(); + for point in points { + point + .write_data_point_to(&mut raw_line_protocol) + .context(WriteToVecSnafu)?; + } + let line_protocol = String::from_utf8(raw_line_protocol) + .expect("Generator should be creating valid UTF-8"); + + let batches_by_measurement = + lines_to_batches(&line_protocol, 0).context(ConvertToMutableBatchSnafu)?; + + for (measurement, batch) in batches_by_measurement { + let record_batch = batch + .to_arrow(Projection::All) + .context(ConvertToArrowSnafu)?; + let stream = Box::pin(MemoryStream::new(vec![record_batch])); + let meta = IoxMetadata::external(crate::now_ns(), &*measurement); + let pool = unbounded_memory_pool(); + let (data, _parquet_file_meta) = + 
serialize::to_parquet_bytes(stream, &meta, pool) + .await + .context(ParquetSerializationSnafu)?; + let data = Bytes::from(data); + + let mut filename = dir_path.clone(); + filename.push(format!("{agent_name}_{measurement}")); + filename.set_extension("parquet"); + + let file = OpenOptions::new() + .create(true) + .write(true) + .open(&filename) + .context(CantOpenParquetFileSnafu { filename })?; + + let mut file = BufWriter::new(file); + + file.write_all(&data).context(WriteToParquetFileSnafu)?; + } + } + + Self::NoOp { perform_write } => { + if *perform_write { + let mut sink = std::io::sink(); + + for point in points { + point + .write_data_point_to(&mut sink) + .context(CantWriteToNoOpSnafu)?; + } + } + } + #[cfg(test)] + Self::Vec(vec) => { + let vec_ref = Arc::clone(vec); + let mut vec = vec_ref.lock().expect("Should be able to get lock"); + for point in points { + point + .write_data_point_to(&mut *vec) + .expect("Should be able to write to vec"); + } + } + Self::Stdout => { + for point in points { + point + .write_data_point_to(std::io::stdout()) + .expect("should be able to write to stdout"); + } + } + } + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{generate, now_ns, specification::*}; + use std::str::FromStr; + + type Error = Box; + type Result = std::result::Result; + + impl PointsWriterBuilder { + fn new_vec() -> Self { + Self { + config: PointsWriterConfig::Vector(BTreeMap::new()), + } + } + + fn written_data(self, agent_name: &str) -> String { + match self.config { + PointsWriterConfig::Vector(agents_by_name) => { + let bytes_ref = + Arc::clone(agents_by_name.get(agent_name).expect( + "Should have written some data, did not find any for this agent", + )); + let bytes = bytes_ref + .lock() + .expect("Should have been able to get a lock"); + String::from_utf8(bytes.to_vec()).expect("we should be generating valid UTF-8") + } + _ => unreachable!("this method is only valid when writing to a vector for testing"), + } + } + } + + 
#[tokio::test] + async fn test_generate() -> Result<()> { + let toml = r#" +name = "demo_schema" + +[[agents]] +name = "foo" + +[[agents.measurements]] +name = "cpu" + +[[agents.measurements.fields]] +name = "val" +i64_range = [3,3] + +[[database_writers]] +agents = [{name = "foo", sampling_interval = "1s"}] +"#; + + let data_spec = DataSpec::from_str(toml).unwrap(); + let mut points_writer_builder = PointsWriterBuilder::new_vec(); + + let now = now_ns(); + + generate( + &data_spec, + vec!["foo_bar".to_string()], + &mut points_writer_builder, + Some(now), + Some(now), + now, + false, + 1, + false, + ) + .await?; + + let line_protocol = points_writer_builder.written_data("foo"); + + let expected_line_protocol = format!( + r#"cpu val=3i {now} +"# + ); + assert_eq!(line_protocol, expected_line_protocol); + + Ok(()) + } + + #[tokio::test] + async fn test_generate_batches() -> Result<()> { + let toml = r#" +name = "demo_schema" + +[[agents]] +name = "foo" + +[[agents.measurements]] +name = "cpu" + +[[agents.measurements.fields]] +name = "val" +i64_range = [2, 2] + +[[database_writers]] +agents = [{name = "foo", sampling_interval = "1s"}] +"#; + + let data_spec = DataSpec::from_str(toml).unwrap(); + let mut points_writer_builder = PointsWriterBuilder::new_vec(); + + let now = now_ns(); + + generate( + &data_spec, + vec!["foo_bar".to_string()], + &mut points_writer_builder, + Some(now - 1_000_000_000), + Some(now), + now, + false, + 2, + false, + ) + .await?; + + let line_protocol = points_writer_builder.written_data("foo"); + + let expected_line_protocol = format!( + r#"cpu val=2i {} +cpu val=2i {} +"#, + now - 1_000_000_000, + now + ); + assert_eq!(line_protocol, expected_line_protocol); + + Ok(()) + } +} diff --git a/iox_query/Cargo.toml b/iox_query/Cargo.toml new file mode 100644 index 0000000..e453531 --- /dev/null +++ b/iox_query/Cargo.toml @@ -0,0 +1,55 @@ +[package] +name = "iox_query" +description = "IOx Query Interface and Executor" +version.workspace = true 
+authors.workspace = true +edition.workspace = true +license.workspace = true + +[lints] +workspace = true + +# This crate is designed to be independent of the rest of the IOx +# server and specific storage systems such as Mutable Buffer and Read Buffer. +# +# The rationale for this is to: +# +# 1. Keep change/compile/link time down during development when working on just this crate +# 2. Allow for query logic testing without bringing in all the storage systems. + +[dependencies] # In alphabetical order +arrow = { workspace = true } +arrow_util = { path = "../arrow_util" } +async-trait = "0.1" +chrono = { version = "0.4", default-features = false } +data_types = { path = "../data_types" } +datafusion = { workspace = true } +datafusion_util = { path = "../datafusion_util" } +executor = { path = "../executor"} +futures = "0.3" +hashbrown = { workspace = true } +indexmap = { version = "2.1", features = ["std"] } +itertools = "0.12.0" +iox_time = { path = "../iox_time" } +metric = { path = "../metric" } +object_store = { workspace = true } +observability_deps = { path = "../observability_deps" } +once_cell = "1" +parking_lot = "0.12" +parquet_file = { path = "../parquet_file" } +query_functions = { path = "../query_functions"} +schema = { path = "../schema" } +snafu = "0.8" +tokio = { version = "1.35", features = ["macros", "parking_lot"] } +tokio-stream = "0.1" +trace = { path = "../trace" } +tracker = { path = "../tracker" } +predicate = { path = "../predicate" } +uuid = { version = "1", features = ["v4"] } +workspace-hack = { version = "0.1", path = "../workspace-hack" } + +[dev-dependencies] # In alphabetical order +test_helpers = { path = "../test_helpers" } +assert_matches = "1" +insta = { version = "1", features = ["yaml"] } +serde = { version = "1.0", features = ["derive"] } diff --git a/iox_query/README.md b/iox_query/README.md new file mode 100644 index 0000000..c522983 --- /dev/null +++ b/iox_query/README.md @@ -0,0 +1,3 @@ +# IOx Query Layer + +See 
[InfluxDB IOx -- Query Processing](../docs/query_processing.md) for details. diff --git a/iox_query/src/chunk_statistics.rs b/iox_query/src/chunk_statistics.rs new file mode 100644 index 0000000..0430347 --- /dev/null +++ b/iox_query/src/chunk_statistics.rs @@ -0,0 +1,289 @@ +//! Tools to set up DataFusion statistics. + +use std::{collections::HashMap, sync::Arc}; + +use data_types::TimestampMinMax; +use datafusion::common::stats::Precision; +use datafusion::{ + physical_plan::{ColumnStatistics, Statistics}, + scalar::ScalarValue, +}; +use datafusion_util::{option_to_precision, timestamptz_nano}; +use schema::{InfluxColumnType, Schema}; + +/// Represent known min/max values for a specific column. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ColumnRange { + pub min_value: Arc, + pub max_value: Arc, +} + +/// Represents the known min/max values for a subset (not all) of the columns in a partition. +/// +/// The values may not actually in any row. +/// +/// These ranges apply to ALL rows (esp. in ALL files and ingester chunks) within in given partition. +pub type ColumnRanges = Arc, ColumnRange>>; + +/// Returns the min/max values for the range, if present +fn range_to_min_max_stats( + range: Option<&ColumnRange>, +) -> (Precision, Precision) { + let Some(range) = range else { + return (Precision::Absent, Precision::Absent); + }; + ( + Precision::Exact(range.min_value.as_ref().clone()), + Precision::Exact(range.max_value.as_ref().clone()), + ) +} + +/// Create chunk [statistics](Statistics). 
+pub fn create_chunk_statistics( + row_count: Option, + schema: &Schema, + ts_min_max: Option, + ranges: Option<&ColumnRanges>, +) -> Statistics { + let mut columns = Vec::with_capacity(schema.len()); + + for (t, field) in schema.iter() { + let stats = match t { + InfluxColumnType::Timestamp => { + // prefer explicitely given time range but fall back to column ranges + let (min_value, max_value) = match ts_min_max { + Some(ts_min_max) => ( + Precision::Exact(timestamptz_nano(ts_min_max.min)), + Precision::Exact(timestamptz_nano(ts_min_max.max)), + ), + None => { + let range = + ranges.and_then(|ranges| ranges.get::(field.name().as_ref())); + + range_to_min_max_stats(range) + } + }; + + ColumnStatistics { + null_count: Precision::Exact(0), + min_value, + max_value, + distinct_count: Precision::Absent, + } + } + _ => { + let range = ranges.and_then(|ranges| ranges.get::(field.name().as_ref())); + + let (min_value, max_value) = range_to_min_max_stats(range); + + ColumnStatistics { + null_count: Precision::Absent, + min_value, + max_value, + distinct_count: Precision::Absent, + } + } + }; + columns.push(stats) + } + + let num_rows = option_to_precision(row_count); + + Statistics { + num_rows, + total_byte_size: Precision::Absent, + column_statistics: columns, + } +} + +#[cfg(test)] +mod tests { + use schema::{InfluxFieldType, SchemaBuilder, TIME_COLUMN_NAME}; + + use super::*; + + #[test] + fn test_create_chunk_statistics_no_columns_no_rows() { + let schema = SchemaBuilder::new().build().unwrap(); + let row_count = 0; + + let actual = create_chunk_statistics(Some(row_count), &schema, None, None); + let expected = Statistics { + num_rows: Precision::Exact(row_count), + total_byte_size: Precision::Absent, + column_statistics: vec![], + }; + assert_eq!(actual, expected); + } + + #[test] + fn test_create_chunk_statistics_no_columns_null_rows() { + let schema = SchemaBuilder::new().build().unwrap(); + + let actual = create_chunk_statistics(None, &schema, None, None); + let 
expected = Statistics { + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![], + }; + assert_eq!(actual, expected); + } + + #[test] + fn test_create_chunk_statistics() { + let schema = full_schema(); + let ts_min_max = TimestampMinMax { min: 10, max: 20 }; + let ranges = Arc::new(HashMap::from([ + ( + Arc::from("tag1"), + ColumnRange { + min_value: Arc::new(ScalarValue::from("aaa")), + max_value: Arc::new(ScalarValue::from("bbb")), + }, + ), + ( + Arc::from("tag3"), // does not exist in schema + ColumnRange { + min_value: Arc::new(ScalarValue::from("ccc")), + max_value: Arc::new(ScalarValue::from("ddd")), + }, + ), + ( + Arc::from("field_integer"), + ColumnRange { + min_value: Arc::new(ScalarValue::from(10i64)), + max_value: Arc::new(ScalarValue::from(20i64)), + }, + ), + ])); + + for row_count in [0usize, 1337usize] { + let actual = + create_chunk_statistics(Some(row_count), &schema, Some(ts_min_max), Some(&ranges)); + let expected = Statistics { + num_rows: Precision::Exact(row_count), + total_byte_size: Precision::Absent, + column_statistics: vec![ + // tag1 + ColumnStatistics { + null_count: Precision::Absent, + min_value: Precision::Exact(ScalarValue::from("aaa")), + max_value: Precision::Exact(ScalarValue::from("bbb")), + distinct_count: Precision::Absent, + }, + // tag2 + ColumnStatistics::default(), + // field_bool + ColumnStatistics::default(), + // field_float + ColumnStatistics::default(), + // field_integer + ColumnStatistics { + null_count: Precision::Absent, + min_value: Precision::Exact(ScalarValue::from(10i64)), + max_value: Precision::Exact(ScalarValue::from(20i64)), + distinct_count: Precision::Absent, + }, + // field_string + ColumnStatistics::default(), + // field_uinteger + ColumnStatistics::default(), + // time + ColumnStatistics { + null_count: Precision::Exact(0), + min_value: Precision::Exact(timestamptz_nano(10)), + max_value: Precision::Exact(timestamptz_nano(20)), + distinct_count: 
Precision::Absent, + }, + ], + }; + assert_eq!(actual, expected); + } + } + + #[test] + fn test_create_chunk_statistics_ts_min_max_overrides_column_range() { + let schema = full_schema(); + let row_count = 42usize; + let ts_min_max = TimestampMinMax { min: 10, max: 20 }; + let ranges = Arc::new(HashMap::from([( + Arc::from(TIME_COLUMN_NAME), + ColumnRange { + min_value: Arc::new(timestamptz_nano(12)), + max_value: Arc::new(timestamptz_nano(22)), + }, + )])); + + let actual = + create_chunk_statistics(Some(row_count), &schema, Some(ts_min_max), Some(&ranges)); + let expected = Statistics { + num_rows: Precision::Exact(row_count), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics::default(), + ColumnStatistics::default(), + ColumnStatistics::default(), + ColumnStatistics::default(), + ColumnStatistics::default(), + ColumnStatistics::default(), + ColumnStatistics::default(), + ColumnStatistics { + null_count: Precision::Exact(0), + min_value: Precision::Exact(timestamptz_nano(10)), + max_value: Precision::Exact(timestamptz_nano(20)), + distinct_count: Precision::Absent, + }, + ], + }; + assert_eq!(actual, expected); + } + + #[test] + fn test_create_chunk_statistics_ts_min_max_none_so_fallback_to_column_range() { + let schema = full_schema(); + let row_count = 42usize; + let ranges = Arc::new(HashMap::from([( + Arc::from(TIME_COLUMN_NAME), + ColumnRange { + min_value: Arc::new(timestamptz_nano(12)), + max_value: Arc::new(timestamptz_nano(22)), + }, + )])); + + let actual = create_chunk_statistics(Some(row_count), &schema, None, Some(&ranges)); + let expected = Statistics { + num_rows: Precision::Exact(row_count), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics::default(), + ColumnStatistics::default(), + ColumnStatistics::default(), + ColumnStatistics::default(), + ColumnStatistics::default(), + ColumnStatistics::default(), + ColumnStatistics::default(), + ColumnStatistics { + null_count: 
Precision::Exact(0), + min_value: Precision::Exact(timestamptz_nano(12)), + max_value: Precision::Exact(timestamptz_nano(22)), + distinct_count: Precision::Absent, + }, + ], + }; + assert_eq!(actual, expected); + } + + fn full_schema() -> Schema { + SchemaBuilder::new() + .tag("tag1") + .tag("tag2") + .influx_field("field_bool", InfluxFieldType::Boolean) + .influx_field("field_float", InfluxFieldType::Float) + .influx_field("field_integer", InfluxFieldType::Integer) + .influx_field("field_string", InfluxFieldType::String) + .influx_field("field_uinteger", InfluxFieldType::UInteger) + .timestamp() + .build() + .unwrap() + } +} diff --git a/iox_query/src/config.rs b/iox_query/src/config.rs new file mode 100644 index 0000000..6dc2235 --- /dev/null +++ b/iox_query/src/config.rs @@ -0,0 +1,94 @@ +use std::{str::FromStr, time::Duration}; + +use datafusion::{common::extensions_options, config::ConfigExtension}; + +/// IOx-specific config extension prefix. +pub const IOX_CONFIG_PREFIX: &str = "iox"; + +extensions_options! { + /// Config options for IOx. + pub struct IoxConfigExt { + /// When splitting de-duplicate operations based on IOx partitions[^iox_part], this is the maximum number of IOx + /// partitions that should be considered. If there are more partitions, the split will NOT be performed. + /// + /// This protects against certain highly degenerative plans. + /// + /// + /// [^iox_part]: "IOx partition" refers to a partition within the IOx catalog, i.e. a partition within the + /// primary key space. This is NOT the same as a DataFusion partition which refers to a stream + /// within the physical plan data flow. + pub max_dedup_partition_split: usize, default = 10_000 + + /// When splitting de-duplicate operations based on time-based overlaps, this is the maximum number of groups + /// that should be considered. If there are more groups, the split will NOT be performed. + /// + /// This protects against certain highly degenerative plans. 
+ pub max_dedup_time_split: usize, default = 100 + + /// When multiple parquet files are required in a sorted way (e.g. for de-duplication), we have two options: + /// + /// 1. **In-mem sorting:** Put them into [`target_partitions`] DataFusion partitions. This limits the fan-out, + /// but requires that we potentially chain multiple parquet files into a single DataFusion partition. Since + /// chaining sorted data does NOT automatically result in sorted data (e.g. AB-AB is not sorted), we need to + /// perform an in-memory sort using [`SortExec`] afterwards. This is expensive. + /// 2. **Fan-out:** Instead of chaining files within DataFusion partitions, we can accept a fan-out beyond + /// [`target_partitions`]. This prevents in-memory sorting but may result in OOMs (out-of-memory). + /// + /// We try to pick option 2 up to a certain number of files, which is configured by this setting. + /// + /// + /// [`SortExec`]: datafusion::physical_plan::sorts::sort::SortExec + /// [`target_partitions`]: datafusion::common::config::ExecutionOptions::target_partitions + pub max_parquet_fanout: usize, default = 40 + + /// Cutoff date for InfluxQL metadata queries. + pub influxql_metadata_cutoff: MetadataCutoff, default = MetadataCutoff::Relative(Duration::from_secs(3600 * 24)) + } +} + +impl ConfigExtension for IoxConfigExt { + const PREFIX: &'static str = IOX_CONFIG_PREFIX; +} + +/// Optional datetime. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MetadataCutoff { + Absolute(chrono::DateTime), + Relative(Duration), +} + +#[derive(Debug)] +pub struct ParseError(String); + +impl std::fmt::Display for ParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl std::error::Error for ParseError {} + +impl FromStr for MetadataCutoff { + type Err = ParseError; + + fn from_str(s: &str) -> Result { + if let Some(s) = s.strip_prefix('-') { + let delta = u64::from_str(s).map_err(|e| ParseError(e.to_string()))?; + let delta = Duration::from_nanos(delta); + Ok(Self::Relative(delta)) + } else { + let dt = chrono::DateTime::::from_str(s) + .map_err(|e| ParseError(e.to_string()))?; + Ok(Self::Absolute(dt)) + } + } +} + +impl std::fmt::Display for MetadataCutoff { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Relative(delta) => write!(f, "-{}", delta.as_nanos()), + Self::Absolute(dt) => write!(f, "{}", dt), + } + } +} diff --git a/iox_query/src/exec.rs b/iox_query/src/exec.rs new file mode 100644 index 0000000..abb8ba5 --- /dev/null +++ b/iox_query/src/exec.rs @@ -0,0 +1,814 @@ +//! This module handles the manipulation / execution of storage +//! plans. This is currently implemented using DataFusion, and this +//! 
interface abstracts away many of the details +pub(crate) mod context; +pub mod field; +pub mod fieldlist; +pub mod gapfill; +mod metrics; +mod non_null_checker; +pub mod query_tracing; +mod schema_pivot; +pub mod seriesset; +pub mod sleep; +pub(crate) mod split; +pub mod stringset; +use datafusion_util::config::register_iox_object_store; +use executor::DedicatedExecutor; +use metric::Registry; +use object_store::DynObjectStore; +use parquet_file::storage::StorageId; +mod cross_rt_stream; + +use std::{collections::HashMap, fmt::Display, num::NonZeroUsize, sync::Arc}; + +use datafusion::{ + self, + execution::{ + disk_manager::DiskManagerConfig, + memory_pool::MemoryPool, + runtime_env::{RuntimeConfig, RuntimeEnv}, + }, + logical_expr::{expr_rewriter::normalize_col, Extension}, + logical_expr::{Expr, LogicalPlan}, +}; + +pub use context::{IOxSessionConfig, IOxSessionContext, SessionContextIOxExt}; +use schema_pivot::SchemaPivotNode; + +use crate::exec::metrics::DataFusionMemoryPoolMetricsBridge; + +use self::{non_null_checker::NonNullCheckerNode, split::StreamSplitNode}; + +const TESTING_MEM_POOL_SIZE: usize = 1024 * 1024 * 1024; // 1GB + +/// Configuration for an Executor +#[derive(Debug, Clone)] +pub struct ExecutorConfig { + /// Number of threads per thread pool + pub num_threads: NonZeroUsize, + + /// Target parallelism for query execution + pub target_query_partitions: NonZeroUsize, + + /// Object stores + pub object_stores: HashMap>, + + /// Metric registry + pub metric_registry: Arc, + + /// Memory pool size in bytes. 
+ pub mem_pool_size: usize, +} + +impl ExecutorConfig { + pub fn testing() -> Self { + Self { + num_threads: NonZeroUsize::new(1).unwrap(), + target_query_partitions: NonZeroUsize::new(1).unwrap(), + object_stores: HashMap::default(), + metric_registry: Arc::new(Registry::default()), + mem_pool_size: TESTING_MEM_POOL_SIZE, + } + } +} + +impl Display for ExecutorConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "num_threads={}, target_query_partitions={}, mem_pool_size={}", + self.num_threads, self.target_query_partitions, self.mem_pool_size + ) + } +} + +#[derive(Debug)] +pub struct DedicatedExecutors { + /// Executor for running user queries + query_exec: DedicatedExecutor, + + /// Executor for running system/reorganization tasks such as + /// compact + reorg_exec: DedicatedExecutor, + + /// Number of threads per thread pool + num_threads: NonZeroUsize, +} + +impl DedicatedExecutors { + pub fn new(num_threads: NonZeroUsize, metric_registry: Arc) -> Self { + let query_exec = + DedicatedExecutor::new("IOx Query", num_threads, Arc::clone(&metric_registry)); + let reorg_exec = DedicatedExecutor::new("IOx Reorg", num_threads, metric_registry); + + Self { + query_exec, + reorg_exec, + num_threads, + } + } + + pub fn new_testing() -> Self { + let query_exec = DedicatedExecutor::new_testing(); + let reorg_exec = DedicatedExecutor::new_testing(); + assert_eq!(query_exec.num_threads(), reorg_exec.num_threads()); + let num_threads = query_exec.num_threads(); + Self { + query_exec, + reorg_exec, + num_threads, + } + } + + pub fn num_threads(&self) -> NonZeroUsize { + self.num_threads + } +} + +/// Handles executing DataFusion plans, and marshalling the results into rust +/// native structures. 
+#[derive(Debug)] +pub struct Executor { + /// Executors + executors: Arc, + + /// The default configuration options with which to create contexts + config: ExecutorConfig, + + /// The DataFusion [RuntimeEnv] (including memory manager and disk + /// manager) used for all executions + runtime: Arc, +} + +impl Display for Executor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Executor({})", self.config) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ExecutorType { + /// Run using the pool for queries + Query, + + /// Run using the pool for system / reorganization tasks + Reorg, +} + +impl Executor { + /// Creates a new executor with two dedicated thread pools, each + /// with num_threads + pub fn new( + num_threads: NonZeroUsize, + mem_pool_size: usize, + metric_registry: Arc, + ) -> Self { + Self::new_with_config(ExecutorConfig { + num_threads, + target_query_partitions: num_threads, + object_stores: HashMap::default(), + metric_registry, + mem_pool_size, + }) + } + + /// Create new executor based on a specific config. + pub fn new_with_config(config: ExecutorConfig) -> Self { + let executors = Arc::new(DedicatedExecutors::new( + config.num_threads, + Arc::clone(&config.metric_registry), + )); + Self::new_with_config_and_executors(config, executors) + } + + /// Get testing executor that runs on a single thread with a low memory bound + /// to preserve resources. + pub fn new_testing() -> Self { + let config = ExecutorConfig::testing(); + let executors = Arc::new(DedicatedExecutors::new_testing()); + Self::new_with_config_and_executors(config, executors) + } + + /// Low-level constructor. + /// + /// This is mostly useful if you want to keep the executors (because they are quite expensive to create) but need a fresh IOx runtime. + /// + /// # Panic + /// Panics if the number of threads in `executors` is different from `config`. 
+ pub fn new_with_config_and_executors( + config: ExecutorConfig, + executors: Arc, + ) -> Self { + assert_eq!(config.num_threads, executors.num_threads); + + let runtime_config = RuntimeConfig::new() + .with_disk_manager(DiskManagerConfig::Disabled) + .with_memory_limit(config.mem_pool_size, 1.0); + + let runtime = Arc::new(RuntimeEnv::new(runtime_config).expect("creating runtime")); + for (id, store) in &config.object_stores { + register_iox_object_store(&runtime, id, Arc::clone(store)); + } + + // As there should only be a single memory pool for any executor, + // verify that there was no existing instrument registered (for another pool) + let mut created = false; + let created_captured = &mut created; + let bridge = + DataFusionMemoryPoolMetricsBridge::new(&runtime.memory_pool, config.mem_pool_size); + let bridge_ctor = move || { + *created_captured = true; + bridge + }; + config + .metric_registry + .register_instrument("datafusion_pool", bridge_ctor); + assert!( + created, + "More than one execution pool created: previously existing instrument" + ); + + Self { + executors, + config, + runtime, + } + } + + /// Return a new execution config, suitable for executing a new query or system task. + /// + /// Note that this context (and all its clones) will be shut down once `Executor` is dropped. + pub fn new_execution_config(&self, executor_type: ExecutorType) -> IOxSessionConfig { + let exec = self.executor(executor_type).clone(); + IOxSessionConfig::new(exec, Arc::clone(&self.runtime)) + .with_target_partitions(self.config.target_query_partitions) + } + + /// Create a new execution context, suitable for executing a new query or system task + /// + /// Note that this context (and all its clones) will be shut down once `Executor` is dropped. 
+ pub fn new_context(&self, executor_type: ExecutorType) -> IOxSessionContext { + self.new_execution_config(executor_type).build() + } + + /// Return the execution pool of the specified type + pub fn executor(&self, executor_type: ExecutorType) -> &DedicatedExecutor { + match executor_type { + ExecutorType::Query => &self.executors.query_exec, + ExecutorType::Reorg => &self.executors.reorg_exec, + } + } + + /// Initializes shutdown. + pub fn shutdown(&self) { + self.executors.query_exec.shutdown(); + self.executors.reorg_exec.shutdown(); + } + + /// Stops all subsequent task executions, and waits for the worker + /// thread to complete. Note this will shutdown all created contexts. + /// + /// Only the first call to `join` will actually wait for the + /// executing thread to complete. All other calls to join will + /// complete immediately. + pub async fn join(&self) { + self.executors.query_exec.join().await; + self.executors.reorg_exec.join().await; + } + + /// Returns the memory pool associated with this `Executor` + pub fn pool(&self) -> Arc { + Arc::clone(&self.runtime.memory_pool) + } + + /// Returns underlying config. + pub fn config(&self) -> &ExecutorConfig { + &self.config + } +} + +// No need to implement `Drop` because this is done by DedicatedExecutor already + +/// Create a SchemaPivot node which takes an arbitrary input like +/// ColA | ColB | ColC +/// ------+------+------ +/// 1 | NULL | NULL +/// 2 | 2 | NULL +/// 3 | 2 | NULL +/// +/// And pivots it to a table with a single string column for any +/// columns that had non null values. +/// +/// non_null_column +/// ----------------- +/// "ColA" +/// "ColB" +pub fn make_schema_pivot(input: LogicalPlan) -> LogicalPlan { + let node = Arc::new(SchemaPivotNode::new(input)); + + LogicalPlan::Extension(Extension { node }) +} + +/// Make a NonNullChecker node that takes an arbitrary input array and +/// produces a single string output column that contains +/// +/// 1. 
the single `table_name` string if any of the input columns are non-null +/// 2. zero rows if all of the input columns are null +/// +/// For this input: +/// +/// ColA | ColB | ColC +/// ------+------+------ +/// 1 | NULL | NULL +/// 2 | 2 | NULL +/// 3 | 2 | NULL +/// +/// The output would be (given 'the_table_name' was the table name) +/// +/// non_null_column +/// ----------------- +/// the_table_name +/// +/// However, for this input (All NULL) +/// +/// ColA | ColB | ColC +/// ------+------+------ +/// NULL | NULL | NULL +/// NULL | NULL | NULL +/// NULL | NULL | NULL +/// +/// There would be no output rows +/// +/// non_null_column +/// ----------------- +pub fn make_non_null_checker(table_name: &str, input: LogicalPlan) -> LogicalPlan { + let node = Arc::new(NonNullCheckerNode::new(table_name, input)); + + LogicalPlan::Extension(Extension { node }) +} + +/// Create a StreamSplit node which takes an input stream of record +/// batches and produces multiple output streams based on a list of `N` predicates. +/// The output will have `N+1` streams, and each row is sent to the stream +/// corresponding to the first predicate that evaluates to true, or the last stream if none do. 
+/// +/// For example, if the input looks like: +/// ```text +/// X | time +/// ---+----- +/// a | 1000 +/// b | 4000 +/// c | 2000 +/// ``` +/// +/// A StreamSplit with split_exprs = [`time <= 1000`, `1000 < time <=2000`] will produce the +/// following three output streams (output DataFusion Partitions): +/// +/// +/// ```text +/// X | time +/// ---+----- +/// a | 1000 +/// ``` +/// +/// ```text +/// X | time +/// ---+----- +/// b | 2000 +/// ``` +/// and +/// ```text +/// X | time +/// ---+----- +/// b | 4000 +/// ``` +pub fn make_stream_split(input: LogicalPlan, split_exprs: Vec) -> LogicalPlan { + // rewrite the input expression so that it is fully qualified with the input schema + let split_exprs = split_exprs + .into_iter() + .map(|split_expr| normalize_col(split_expr, &input).expect("normalize is infallable")) + .collect::>(); + + let node = Arc::new(StreamSplitNode::new(input, split_exprs)); + LogicalPlan::Extension(Extension { node }) +} + +#[cfg(test)] +mod tests { + use arrow::{ + array::{ArrayRef, Int64Array, StringArray}, + datatypes::{DataType, Field, Schema, SchemaRef}, + }; + use datafusion::{ + datasource::{provider_as_source, MemTable}, + error::DataFusionError, + logical_expr::LogicalPlanBuilder, + physical_expr::PhysicalSortExpr, + physical_plan::{ + expressions::Column, sorts::sort::SortExec, DisplayAs, ExecutionPlan, RecordBatchStream, + }, + }; + use futures::{stream::BoxStream, Stream, StreamExt}; + use metric::{Observation, RawReporter}; + use stringset::StringSet; + use tokio::sync::Barrier; + + use super::*; + use crate::exec::stringset::StringSetRef; + use crate::plan::stringset::StringSetPlan; + use arrow::record_batch::RecordBatch; + + #[tokio::test] + async fn executor_known_string_set_plan_ok() { + let expected_strings = to_set(&["Foo", "Bar"]); + let plan = StringSetPlan::Known(Arc::clone(&expected_strings)); + + let exec = Executor::new_testing(); + let ctx = exec.new_context(ExecutorType::Query); + let result_strings = 
ctx.to_string_set(plan).await.unwrap(); + assert_eq!(result_strings, expected_strings); + } + + #[tokio::test] + async fn executor_datafusion_string_set_single_plan_no_batches() { + // Test with a single plan that produces no batches + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + let scan = make_plan(schema, vec![]); + let plan: StringSetPlan = vec![scan].into(); + + let exec = Executor::new_testing(); + let ctx = exec.new_context(ExecutorType::Query); + let results = ctx.to_string_set(plan).await.unwrap(); + + assert_eq!(results, StringSetRef::new(StringSet::new())); + } + + #[tokio::test] + async fn executor_datafusion_string_set_single_plan_one_batch() { + // Test with a single plan that produces one record batch + let data = to_string_array(&["foo", "bar", "baz", "foo"]); + let batch = RecordBatch::try_from_iter_with_nullable(vec![("a", data, true)]) + .expect("created new record batch"); + let scan = make_plan(batch.schema(), vec![batch]); + let plan: StringSetPlan = vec![scan].into(); + + let exec = Executor::new_testing(); + let ctx = exec.new_context(ExecutorType::Query); + let results = ctx.to_string_set(plan).await.unwrap(); + + assert_eq!(results, to_set(&["foo", "bar", "baz"])); + } + + #[tokio::test] + async fn executor_datafusion_string_set_single_plan_two_batch() { + // Test with a single plan that produces multiple record batches + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + let data1 = to_string_array(&["foo", "bar"]); + let batch1 = RecordBatch::try_new(Arc::clone(&schema), vec![data1]) + .expect("created new record batch"); + let data2 = to_string_array(&["baz", "foo"]); + let batch2 = RecordBatch::try_new(Arc::clone(&schema), vec![data2]) + .expect("created new record batch"); + let scan = make_plan(schema, vec![batch1, batch2]); + let plan: StringSetPlan = vec![scan].into(); + + let exec = Executor::new_testing(); + let ctx = exec.new_context(ExecutorType::Query); + 
let results = ctx.to_string_set(plan).await.unwrap(); + + assert_eq!(results, to_set(&["foo", "bar", "baz"])); + } + + #[tokio::test] + async fn executor_datafusion_string_set_multi_plan() { + // Test with multiple datafusion logical plans + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + + let data1 = to_string_array(&["foo", "bar"]); + let batch1 = RecordBatch::try_new(Arc::clone(&schema), vec![data1]) + .expect("created new record batch"); + let scan1 = make_plan(Arc::clone(&schema), vec![batch1]); + + let data2 = to_string_array(&["baz", "foo"]); + let batch2 = RecordBatch::try_new(Arc::clone(&schema), vec![data2]) + .expect("created new record batch"); + let scan2 = make_plan(schema, vec![batch2]); + + let plan: StringSetPlan = vec![scan1, scan2].into(); + + let exec = Executor::new_testing(); + let ctx = exec.new_context(ExecutorType::Query); + let results = ctx.to_string_set(plan).await.unwrap(); + + assert_eq!(results, to_set(&["foo", "bar", "baz"])); + } + + #[tokio::test] + async fn executor_datafusion_string_set_nulls() { + // Ensure that nulls in the output set are handled reasonably + // (error, rather than silently ignored) + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + let array = StringArray::from_iter(vec![Some("foo"), None]); + let data = Arc::new(array); + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data]) + .expect("created new record batch"); + let scan = make_plan(schema, vec![batch]); + let plan: StringSetPlan = vec![scan].into(); + + let exec = Executor::new_testing(); + let ctx = exec.new_context(ExecutorType::Query); + let results = ctx.to_string_set(plan).await; + + let actual_error = match results { + Ok(_) => "Unexpected Ok".into(), + Err(e) => format!("{e}"), + }; + let expected_error = "unexpected null value"; + assert!( + actual_error.contains(expected_error), + "expected error '{expected_error}' not found in '{actual_error:?}'", + ); + } + + 
#[tokio::test] + async fn executor_datafusion_string_set_bad_schema() { + // Ensure that an incorrect schema (an int) gives a reasonable error + let data: ArrayRef = Arc::new(Int64Array::from(vec![1])); + let batch = + RecordBatch::try_from_iter(vec![("a", data)]).expect("created new record batch"); + let scan = make_plan(batch.schema(), vec![batch]); + let plan: StringSetPlan = vec![scan].into(); + + let exec = Executor::new_testing(); + let ctx = exec.new_context(ExecutorType::Query); + let results = ctx.to_string_set(plan).await; + + let actual_error = match results { + Ok(_) => "Unexpected Ok".into(), + Err(e) => format!("{e}"), + }; + + let expected_error = "schema not a single Utf8"; + assert!( + actual_error.contains(expected_error), + "expected error '{expected_error}' not found in '{actual_error:?}'" + ); + } + + #[tokio::test] + async fn make_schema_pivot_is_planned() { + // Test that all the planning logic is wired up and that we + // can make a plan using a SchemaPivot node + let batch = RecordBatch::try_from_iter_with_nullable(vec![ + ("f1", to_string_array(&["foo", "bar"]), true), + ("f2", to_string_array(&["baz", "bzz"]), true), + ]) + .expect("created new record batch"); + + let scan = make_plan(batch.schema(), vec![batch]); + let pivot = make_schema_pivot(scan); + let plan = vec![pivot].into(); + + let exec = Executor::new_testing(); + let ctx = exec.new_context(ExecutorType::Query); + let results = ctx.to_string_set(plan).await.expect("Executed plan"); + + assert_eq!(results, to_set(&["f1", "f2"])); + } + + #[tokio::test] + async fn test_metrics_integration() { + let exec = Executor::new_testing(); + + // start w/o any reservation + assert_eq!( + PoolMetrics::read(&exec.config.metric_registry), + PoolMetrics { + reserved: 0, + limit: TESTING_MEM_POOL_SIZE as u64, + }, + ); + + // block some reservation + let test_input = Arc::new(TestExec::default()); + let schema = test_input.schema(); + let plan = Arc::new(SortExec::new( + vec![PhysicalSortExpr { 
+ expr: Arc::new(Column::new_with_schema("c", &schema).unwrap()), + options: Default::default(), + }], + Arc::clone(&test_input) as _, + )); + let ctx = exec.new_context(ExecutorType::Query); + let handle = tokio::spawn(async move { + ctx.collect(plan).await.unwrap(); + }); + test_input.wait().await; + assert_eq!( + PoolMetrics::read(&exec.config.metric_registry), + PoolMetrics { + reserved: 896, + limit: TESTING_MEM_POOL_SIZE as u64, + }, + ); + test_input.wait_for_finish().await; + + // end w/o any reservation + handle.await.unwrap(); + assert_eq!( + PoolMetrics::read(&exec.config.metric_registry), + PoolMetrics { + reserved: 0, + limit: TESTING_MEM_POOL_SIZE as u64, + }, + ); + } + + /// return a set for testing + fn to_set(strs: &[&str]) -> StringSetRef { + StringSetRef::new(strs.iter().map(|s| s.to_string()).collect::()) + } + + fn to_string_array(strs: &[&str]) -> ArrayRef { + let array: StringArray = strs.iter().map(|s| Some(*s)).collect(); + Arc::new(array) + } + + // creates a DataFusion plan that reads the RecordBatches into memory + fn make_plan(schema: SchemaRef, data: Vec) -> LogicalPlan { + let partitions = vec![data]; + + let projection = None; + + // model one partition, + let table = MemTable::try_new(schema, partitions).unwrap(); + let source = provider_as_source(Arc::new(table)); + + LogicalPlanBuilder::scan("memtable", source, projection) + .unwrap() + .build() + .unwrap() + } + + #[derive(Debug)] + struct TestExec { + schema: SchemaRef, + // Barrier after a batch has been produced + barrier: Arc, + // Barrier right before the operator is complete + barrier_finish: Arc, + } + + impl Default for TestExec { + fn default() -> Self { + Self { + schema: Arc::new(arrow::datatypes::Schema::new(vec![Field::new( + "c", + DataType::Int64, + true, + )])), + barrier: Arc::new(Barrier::new(2)), + barrier_finish: Arc::new(Barrier::new(2)), + } + } + } + + impl TestExec { + /// wait for the first output to be produced + pub async fn wait(&self) { + 
self.barrier.wait().await; + } + + /// wait for output to be done + pub async fn wait_for_finish(&self) { + self.barrier_finish.wait().await; + } + } + + impl DisplayAs for TestExec { + fn fmt_as( + &self, + _t: datafusion::physical_plan::DisplayFormatType, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "TestExec") + } + } + + impl ExecutionPlan for TestExec { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning { + datafusion::physical_plan::Partitioning::UnknownPartitioning(1) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> datafusion::error::Result> { + unimplemented!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> datafusion::error::Result + { + let barrier = Arc::clone(&self.barrier); + let schema = Arc::clone(&self.schema); + let barrier_finish = Arc::clone(&self.barrier_finish); + let schema_finish = Arc::clone(&self.schema); + let stream = futures::stream::iter([Ok(RecordBatch::try_new( + Arc::clone(&self.schema), + vec![Arc::new(Int64Array::from(vec![1i64; 100]))], + ) + .unwrap())]) + .chain(futures::stream::once(async move { + barrier.wait().await; + Ok(RecordBatch::new_empty(schema)) + })) + .chain(futures::stream::once(async move { + barrier_finish.wait().await; + Ok(RecordBatch::new_empty(schema_finish)) + })); + let stream = BoxRecordBatchStream { + schema: Arc::clone(&self.schema), + inner: stream.boxed(), + }; + Ok(Box::pin(stream)) + } + + fn statistics(&self) -> Result { + Ok(datafusion::physical_plan::Statistics::new_unknown( + &self.schema(), + )) + } + } + + struct BoxRecordBatchStream { + schema: SchemaRef, + inner: BoxStream<'static, Result>, + } + + impl Stream for BoxRecordBatchStream { + type Item = 
Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = &mut *self; + this.inner.poll_next_unpin(cx) + } + } + + impl RecordBatchStream for BoxRecordBatchStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + } + + #[derive(Debug, PartialEq, Eq)] + struct PoolMetrics { + reserved: u64, + limit: u64, + } + + impl PoolMetrics { + fn read(registry: &Registry) -> Self { + let mut reporter = RawReporter::default(); + registry.report(&mut reporter); + let metric = reporter.metric("datafusion_mem_pool_bytes").unwrap(); + + let reserved = metric.observation(&[("state", "reserved")]).unwrap(); + let Observation::U64Gauge(reserved) = reserved else { + panic!("wrong metric type") + }; + let limit = metric.observation(&[("state", "limit")]).unwrap(); + let Observation::U64Gauge(limit) = limit else { + panic!("wrong metric type") + }; + + Self { + reserved: *reserved, + limit: *limit, + } + } + } +} diff --git a/iox_query/src/exec/context.rs b/iox_query/src/exec/context.rs new file mode 100644 index 0000000..ad60c7a --- /dev/null +++ b/iox_query/src/exec/context.rs @@ -0,0 +1,753 @@ +//! This module contains plumbing to connect InfluxDB IOx extensions to +//! 
DataFusion + +use super::{ + cross_rt_stream::CrossRtStream, + gapfill::{plan_gap_fill, GapFill}, + non_null_checker::NonNullCheckerNode, + seriesset::{series::Either, SeriesSet}, + sleep::SleepNode, + split::StreamSplitNode, +}; +use crate::{ + config::IoxConfigExt, + exec::{ + fieldlist::{FieldList, IntoFieldList}, + non_null_checker::NonNullCheckerExec, + query_tracing::TracedStream, + schema_pivot::{SchemaPivotExec, SchemaPivotNode}, + seriesset::{ + converter::{GroupGenerator, SeriesSetConverter}, + series::Series, + }, + split::StreamSplitExec, + stringset::{IntoStringSet, StringSetRef}, + }, + logical_optimizer::register_iox_logical_optimizers, + physical_optimizer::register_iox_physical_optimizers, + plan::{ + fieldlist::FieldListPlan, + seriesset::{SeriesSetPlan, SeriesSetPlans}, + stringset::StringSetPlan, + }, +}; +use arrow::record_batch::RecordBatch; +use async_trait::async_trait; +use datafusion::{ + catalog::CatalogProvider, + common::ParamValues, + execution::{ + context::{QueryPlanner, SessionState, TaskContext}, + memory_pool::MemoryPool, + runtime_env::RuntimeEnv, + }, + logical_expr::{LogicalPlan, UserDefinedLogicalNode}, + physical_plan::{ + coalesce_partitions::CoalescePartitionsExec, displayable, stream::RecordBatchStreamAdapter, + EmptyRecordBatchStream, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, + }, + physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}, + prelude::*, +}; +use datafusion_util::config::{iox_session_config, DEFAULT_CATALOG}; +use executor::DedicatedExecutor; +use futures::{Stream, StreamExt, TryStreamExt}; +use observability_deps::tracing::{debug, warn}; +use query_functions::{register_scalar_functions, selectors::register_selector_aggregates}; +use std::{fmt, num::NonZeroUsize, sync::Arc}; +use trace::{ + ctx::SpanContext, + span::{MetaValue, Span, SpanEvent, SpanExt, SpanRecorder}, +}; + +// Reuse DataFusion error and Result types for this module +pub use 
datafusion::error::{DataFusionError, Result}; + +/// This structure implements the DataFusion notion of "query planner" +/// and is needed to create plans with the IOx extension nodes. +struct IOxQueryPlanner {} + +#[async_trait] +impl QueryPlanner for IOxQueryPlanner { + /// Given a `LogicalPlan` created from above, create an + /// `ExecutionPlan` suitable for execution + async fn create_physical_plan( + &self, + logical_plan: &LogicalPlan, + session_state: &SessionState, + ) -> Result> { + // Teach the default physical planner how to plan SchemaPivot + // and StreamSplit nodes. + let physical_planner = + DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new(IOxExtensionPlanner {})]); + // Delegate most work of physical planning to the default physical planner + physical_planner + .create_physical_plan(logical_plan, session_state) + .await + } +} + +/// Physical planner for InfluxDB IOx extension plans +struct IOxExtensionPlanner {} + +#[async_trait] +impl ExtensionPlanner for IOxExtensionPlanner { + /// Create a physical plan for an extension node + async fn plan_extension( + &self, + planner: &dyn PhysicalPlanner, + node: &dyn UserDefinedLogicalNode, + logical_inputs: &[&LogicalPlan], + physical_inputs: &[Arc], + session_state: &SessionState, + ) -> Result>> { + let any = node.as_any(); + let plan = if let Some(schema_pivot) = any.downcast_ref::() { + assert_eq!(physical_inputs.len(), 1, "Inconsistent number of inputs"); + Some(Arc::new(SchemaPivotExec::new( + Arc::clone(&physical_inputs[0]), + schema_pivot.schema().as_ref().clone().into(), + )) as Arc) + } else if let Some(non_null_checker) = any.downcast_ref::() { + assert_eq!(physical_inputs.len(), 1, "Inconsistent number of inputs"); + Some(Arc::new(NonNullCheckerExec::new( + Arc::clone(&physical_inputs[0]), + non_null_checker.schema().as_ref().clone().into(), + non_null_checker.value(), + )) as Arc) + } else if let Some(stream_split) = any.downcast_ref::() { + assert_eq!( + logical_inputs.len(), + 
1, + "Inconsistent number of logical inputs" + ); + assert_eq!( + physical_inputs.len(), + 1, + "Inconsistent number of physical inputs" + ); + + let split_exprs = stream_split + .split_exprs() + .iter() + .map(|e| { + planner.create_physical_expr( + e, + logical_inputs[0].schema(), + &physical_inputs[0].schema(), + session_state, + ) + }) + .collect::>>()?; + + Some(Arc::new(StreamSplitExec::new( + Arc::clone(&physical_inputs[0]), + split_exprs, + )) as Arc) + } else if let Some(gap_fill) = any.downcast_ref::() { + let gap_fill_exec = plan_gap_fill( + session_state.execution_props(), + gap_fill, + logical_inputs, + physical_inputs, + )?; + Some(Arc::new(gap_fill_exec) as Arc) + } else if let Some(sleep) = any.downcast_ref::() { + let sleep = sleep.plan(planner, logical_inputs, physical_inputs, session_state)?; + Some(Arc::new(sleep) as _) + } else { + None + }; + Ok(plan) + } +} + +/// Configuration for an IOx execution context +/// +/// Created from an Executor +#[derive(Clone)] +pub struct IOxSessionConfig { + /// Executor to run on + exec: DedicatedExecutor, + + /// DataFusion session configuration + session_config: SessionConfig, + + /// Shared DataFusion runtime + runtime: Arc, + + /// Default catalog + default_catalog: Option>, + + /// Span context from which to create spans for this query + span_ctx: Option, +} + +impl fmt::Debug for IOxSessionConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "IOxSessionConfig ...") + } +} + +impl IOxSessionConfig { + pub(super) fn new(exec: DedicatedExecutor, runtime: Arc) -> Self { + let mut session_config = iox_session_config(); + session_config + .options_mut() + .extensions + .insert(IoxConfigExt::default()); + + Self { + exec, + session_config, + runtime, + default_catalog: None, + span_ctx: None, + } + } + + /// Set execution concurrency + pub fn with_target_partitions(mut self, target_partitions: NonZeroUsize) -> Self { + self.session_config = self + .session_config + 
.with_target_partitions(target_partitions.get()); + self + } + + /// Set the default catalog provider + pub fn with_default_catalog(self, catalog: Arc) -> Self { + Self { + default_catalog: Some(catalog), + ..self + } + } + + /// Set the span context from which to create distributed tracing spans for this query + pub fn with_span_context(self, span_ctx: Option) -> Self { + Self { span_ctx, ..self } + } + + /// Set DataFusion [config option]. + /// + /// May be used to set [IOx-specific] option as well. + /// + /// + /// [config option]: datafusion::common::config::ConfigOptions + /// [IOx-specific]: crate::config::IoxConfigExt + pub fn with_config_option(mut self, key: &str, value: &str) -> Self { + // ignore invalid config + if let Err(e) = self.session_config.options_mut().set(key, value) { + warn!( + key, + value, + %e, + "invalid DataFusion config", + ); + } + self + } + + /// Create an ExecutionContext suitable for executing DataFusion plans + pub fn build(self) -> IOxSessionContext { + let maybe_span = self.span_ctx.child_span("Query Execution"); + let recorder = SpanRecorder::new(maybe_span); + + // attach span to DataFusion session + let session_config = self + .session_config + .with_extension(Arc::new(recorder.span().cloned())); + + let state = SessionState::new_with_config_rt(session_config, self.runtime) + .with_query_planner(Arc::new(IOxQueryPlanner {})); + let state = register_iox_physical_optimizers(state); + let state = register_iox_logical_optimizers(state); + + let inner = SessionContext::new_with_state(state); + register_selector_aggregates(&inner); + register_scalar_functions(&inner); + if let Some(default_catalog) = self.default_catalog { + inner.register_catalog(DEFAULT_CATALOG, default_catalog); + } + + IOxSessionContext::new(inner, self.exec, recorder) + } +} + +/// This is an execution context for planning in IOx. It wraps a +/// DataFusion execution context with the information needed for planning. 
+/// +/// Methods on this struct should be preferred to using the raw +/// DataFusion functions (such as `collect`) directly. +/// +/// Eventually we envision this also managing additional resource +/// types such as Memory and providing visibility into what plans are +/// running +/// +/// An IOxSessionContext is created directly from an Executor, or from +/// an IOxSessionConfig created by an Executor +pub struct IOxSessionContext { + inner: SessionContext, + + /// Dedicated executor for query execution. + /// + /// DataFusion plans are "CPU" bound and thus can consume tokio + /// executors threads for extended periods of time. We use a + /// dedicated tokio runtime to run them so that other requests + /// can be handled. + exec: DedicatedExecutor, + + /// Span context from which to create spans for this query + recorder: SpanRecorder, +} + +impl fmt::Debug for IOxSessionContext { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("IOxSessionContext") + .field("inner", &"") + .field("exec", &self.exec) + .field("recorder", &self.recorder) + .finish() + } +} + +impl IOxSessionContext { + /// Constructor for testing. + /// + /// This is identical to [`Default::default`] but we do NOT implement [`Default`] to make the creation of untracked + /// contexts more explicit. + pub fn with_testing() -> Self { + Self { + inner: SessionContext::default(), + exec: DedicatedExecutor::new_testing(), + recorder: SpanRecorder::default(), + } + } + + /// Private constructor + pub(crate) fn new( + inner: SessionContext, + exec: DedicatedExecutor, + recorder: SpanRecorder, + ) -> Self { + Self { + inner, + exec, + recorder, + } + } + + /// returns a reference to the inner datafusion execution context + pub fn inner(&self) -> &SessionContext { + &self.inner + } + + /// Plan a SQL statement. This assumes that any tables referenced + /// in the SQL have been registered with this context. Use + /// `create_physical_plan` to actually execute the query. 
+ pub async fn sql_to_logical_plan(&self, sql: &str) -> Result { + Self::sql_to_logical_plan_with_params(self, sql, ParamValues::List(vec![])).await + } + + /// Plan a SQL statement, providing a list of parameter values + /// to supply to `$placeholder` variables. This assumes that + /// any tables referenced in the SQL have been registered with + /// this context. Use `create_physical_plan` to actually execute + /// the query. + pub async fn sql_to_logical_plan_with_params( + &self, + sql: &str, + params: impl Into + Send, + ) -> Result { + let ctx = self.child_ctx("sql_to_logical_plan"); + debug!(text=%sql, "planning SQL query"); + let plan = ctx + .inner + .state() + .create_logical_plan(sql) + .await? + .with_param_values(params.into())?; + // ensure the plan does not contain unwanted statements + let verifier = SQLOptions::new() + .with_allow_ddl(false) // no CREATE ... + .with_allow_dml(false) // no INSERT or COPY + .with_allow_statements(false); // no SET VARIABLE, etc + verifier.verify_plan(&plan)?; + Ok(plan) + } + + /// Create a logical plan that reads a single [`RecordBatch`]. Use + /// `create_physical_plan` to actually execute the query. + pub fn batch_to_logical_plan(&self, batch: RecordBatch) -> Result { + let ctx = self.child_ctx("batch_to_logical_plan"); + debug!(num_rows = batch.num_rows(), "planning RecordBatch query"); + ctx.inner.read_batch(batch)?.into_optimized_plan() + } + + /// Plan a SQL statement and convert it to an execution plan. This assumes that any + /// tables referenced in the SQL have been registered with this context + pub async fn sql_to_physical_plan(&self, sql: &str) -> Result> { + Self::sql_to_physical_plan_with_params(self, sql, ParamValues::List(vec![])).await + } + + /// Plan a SQL statement and convert it to an execution plan, providing a list of + /// parameter values to supply to `$placeholder` variables. 
This assumes that any + /// tables referenced in the SQL have been registered with this context + pub async fn sql_to_physical_plan_with_params( + &self, + sql: &str, + params: impl Into + Send, + ) -> Result> { + let ctx = self.child_ctx("sql_to_physical_plan"); + + let logical_plan = ctx.sql_to_logical_plan_with_params(sql, params).await?; + ctx.create_physical_plan(&logical_plan).await + } + + /// Prepare (optimize + plan) a pre-created [`LogicalPlan`] for execution + pub async fn create_physical_plan( + &self, + logical_plan: &LogicalPlan, + ) -> Result> { + let mut ctx = self.child_ctx("create_physical_plan"); + debug!(text=%logical_plan.display_indent_schema(), "create_physical_plan: initial plan"); + let physical_plan = ctx.inner.state().create_physical_plan(logical_plan).await?; + + ctx.recorder.event(SpanEvent::new("physical plan")); + debug!(text=%displayable(physical_plan.as_ref()).indent(false), "create_physical_plan: plan to run"); + Ok(physical_plan) + } + + /// Executes the logical plan using DataFusion on a separate + /// thread pool and produces RecordBatches + pub async fn collect(&self, physical_plan: Arc) -> Result> { + debug!( + "Running plan, physical:\n{}", + displayable(physical_plan.as_ref()).indent(false) + ); + let ctx = self.child_ctx("collect"); + let stream = ctx.execute_stream(physical_plan).await?; + + ctx.run( + stream + .err_into() // convert to DataFusionError + .try_collect(), + ) + .await + } + + /// Executes the physical plan and produces a + /// `SendableRecordBatchStream` to stream over the result that + /// iterates over the results. The creation of the stream is + /// performed in a separate thread pool. 
+ pub async fn execute_stream( + &self, + physical_plan: Arc, + ) -> Result { + match physical_plan.output_partitioning().partition_count() { + 0 => Ok(Box::pin(EmptyRecordBatchStream::new( + physical_plan.schema(), + ))), + 1 => self.execute_stream_partitioned(physical_plan, 0).await, + _ => { + // Merge into a single partition + self.execute_stream_partitioned( + Arc::new(CoalescePartitionsExec::new(physical_plan)), + 0, + ) + .await + } + } + } + + /// Executes a single partition of a physical plan and produces a + /// `SendableRecordBatchStream` to stream over the result that + /// iterates over the results. The creation of the stream is + /// performed in a separate thread pool. + pub async fn execute_stream_partitioned( + &self, + physical_plan: Arc, + partition: usize, + ) -> Result { + let span = self + .recorder + .span() + .map(|span| span.child("execute_stream_partitioned")); + + let task_context = Arc::new(TaskContext::from(self.inner())); + + let stream = self + .run(async move { + let stream = physical_plan.execute(partition, task_context)?; + Ok(TracedStream::new(stream, span, physical_plan)) + }) + .await?; + // Wrap the resulting stream into `CrossRtStream`. This is required because polling the DataFusion result stream + // actually drives the (potentially CPU-bound) work. We need to make sure that this work stays within the + // dedicated executor because otherwise this may block the top-level tokio/tonic runtime which may lead to + // requests timetouts (either for new requests, metrics or even for HTTP2 pings on the active connection). 
+ let schema = stream.schema(); + let stream = CrossRtStream::new_with_df_error_stream(stream, self.exec.clone()); + let stream = RecordBatchStreamAdapter::new(schema, stream); + Ok(Box::pin(stream)) + } + + /// Executes the SeriesSetPlans on the query executor, in + /// parallel, producing series or groups + pub async fn to_series_and_groups( + &self, + series_set_plans: SeriesSetPlans, + memory_pool: Arc, + points_per_batch: usize, + ) -> Result>> { + let SeriesSetPlans { + mut plans, + group_columns, + } = series_set_plans; + + if plans.is_empty() { + return Ok(futures::stream::empty().boxed()); + } + + // sort plans by table (measurement) name + plans.sort_by(|a, b| a.table_name.cmp(&b.table_name)); + + // Run the plans in parallel + let ctx = self.child_ctx("to_series_set"); + let exec = self.exec.clone(); + let data = futures::stream::iter(plans) + .then(move |plan| { + let ctx = ctx.child_ctx("for plan"); + let exec = exec.clone(); + + async move { + let stream = Self::run_inner(exec.clone(), async move { + let SeriesSetPlan { + table_name, + plan, + tag_columns, + field_columns, + } = plan; + + let tag_columns = Arc::new(tag_columns); + + let physical_plan = ctx.create_physical_plan(&plan).await?; + + let it = ctx.execute_stream(physical_plan).await?; + + SeriesSetConverter::default() + .convert(table_name, tag_columns, field_columns, it) + .await + }) + .await?; + + Ok::<_, DataFusionError>(CrossRtStream::new_with_df_error_stream(stream, exec)) + } + }) + .try_flatten() + .try_filter_map(move |series_set: SeriesSet| async move { + // If all timestamps of returned columns are nulls, + // there must be no data. We need to check this because + // aggregate (e.g. count, min, max) returns one row that are + // all null (even the values of aggregate) for min, max and 0 for count. + // For influx read_group's series and group, we do not want to return 0 + // for count either. 
+ if series_set.is_timestamp_all_null() { + return Ok(None); + } + + let series: Vec = + series_set.try_into_series(points_per_batch).map_err(|e| { + DataFusionError::Execution(format!("Error converting to series: {e}")) + })?; + Ok(Some(futures::stream::iter(series).map(Ok))) + }) + .try_flatten(); + + // If we have group columns, sort the results, and create the + // appropriate groups + if let Some(group_columns) = group_columns { + let grouper = GroupGenerator::new(group_columns, memory_pool); + Ok(grouper.group(data).await?.boxed()) + } else { + Ok(data.map_ok(|series| series.into()).boxed()) + } + } + + /// Executes `plan` and return the resulting FieldList on the query executor + pub async fn to_field_list(&self, plan: FieldListPlan) -> Result { + let FieldListPlan { + known_values, + extra_plans, + } = plan; + + // Run the plans in parallel + let handles = extra_plans + .into_iter() + .map(|plan| { + let ctx = self.child_ctx("to_field_list"); + self.run(async move { + let physical_plan = ctx.create_physical_plan(&plan).await?; + + // TODO: avoid this buffering + let field_list = + ctx.collect(physical_plan) + .await? 
+ .into_fieldlist() + .map_err(|e| { + DataFusionError::Context( + "Error converting to field list".to_string(), + Box::new(DataFusionError::External(Box::new(e))), + ) + })?; + + Ok(field_list) + }) + }) + .collect::>(); + + // collect them all up and combine them + let mut results = Vec::new(); + + if !known_values.is_empty() { + let list = known_values.into_iter().map(|f| f.1).collect(); + results.push(FieldList { fields: list }) + } + + for join_handle in handles { + let fieldlist = join_handle.await?; + + results.push(fieldlist); + } + + // TODO: Stream this + results.into_fieldlist().map_err(|e| { + DataFusionError::Context( + "Error converting to field list".to_string(), + Box::new(DataFusionError::External(Box::new(e))), + ) + }) + } + + /// Executes this plan on the query pool, and returns the + /// resulting set of strings + pub async fn to_string_set(&self, plan: StringSetPlan) -> Result { + let ctx = self.child_ctx("to_string_set"); + match plan { + StringSetPlan::Known(ss) => Ok(ss), + StringSetPlan::Plan(plans) => ctx + .run_logical_plans(plans) + .await? 
+ .into_stringset() + .map_err(|e| { + DataFusionError::Context( + "Error converting to stringset".to_string(), + Box::new(DataFusionError::External(Box::new(e))), + ) + }), + } + } + + /// plans and runs the plans in parallel and collects the results + /// run each plan in parallel and collect the results + async fn run_logical_plans(&self, plans: Vec) -> Result> { + let value_futures = plans + .into_iter() + .map(|plan| { + let ctx = self.child_ctx("run_logical_plans"); + self.run(async move { + let physical_plan = ctx.create_physical_plan(&plan).await?; + + // TODO: avoid this buffering + ctx.collect(physical_plan).await + }) + }) + .collect::>(); + + // now, wait for all the values to resolve and collect them together + let mut results = Vec::new(); + for join_handle in value_futures { + let mut plan_result = join_handle.await?; + results.append(&mut plan_result); + } + Ok(results) + } + + /// Runs the provided future using this execution context + pub async fn run(&self, fut: Fut) -> Result + where + Fut: std::future::Future> + Send + 'static, + T: Send + 'static, + { + Self::run_inner(self.exec.clone(), fut).await + } + + async fn run_inner(exec: DedicatedExecutor, fut: Fut) -> Result + where + Fut: std::future::Future> + Send + 'static, + T: Send + 'static, + { + exec.spawn(fut).await.unwrap_or_else(|e| { + Err(DataFusionError::Context( + "Join Error".to_string(), + Box::new(DataFusionError::External(Box::new(e))), + )) + }) + } + + /// Returns a IOxSessionContext with a SpanRecorder that is a child of the current + pub fn child_ctx(&self, name: &'static str) -> Self { + Self::new( + self.inner.clone(), + self.exec.clone(), + self.recorder.child(name), + ) + } + + /// Record an event on the span recorder + pub fn record_event(&mut self, name: &'static str) { + self.recorder.event(SpanEvent::new(name)); + } + + /// Record an event on the span recorder + pub fn set_metadata(&mut self, name: &'static str, value: impl Into) { + self.recorder.set_metadata(name, 
value); + } + + /// Returns the current [`Span`] if any + pub fn span(&self) -> Option<&Span> { + self.recorder.span() + } + + /// Returns a new child span of the current context + pub fn child_span(&self, name: &'static str) -> Option { + self.recorder.child_span(name) + } + + /// Number of currently active tasks. + pub fn tasks(&self) -> usize { + self.exec.tasks() + } +} + +/// Extension trait to pull IOx spans out of DataFusion contexts. +pub trait SessionContextIOxExt { + /// Get child span of the current context. + fn child_span(&self, name: &'static str) -> Option; + + /// Get span context + fn span_ctx(&self) -> Option; +} + +impl SessionContextIOxExt for SessionState { + fn child_span(&self, name: &'static str) -> Option { + self.config() + .get_extension::>() + .and_then(|span| span.as_ref().as_ref().map(|span| span.child(name))) + } + + fn span_ctx(&self) -> Option { + self.config() + .get_extension::>() + .and_then(|span| span.as_ref().as_ref().map(|span| span.ctx.clone())) + } +} diff --git a/iox_query/src/exec/cross_rt_stream.rs b/iox_query/src/exec/cross_rt_stream.rs new file mode 100644 index 0000000..c5303cb --- /dev/null +++ b/iox_query/src/exec/cross_rt_stream.rs @@ -0,0 +1,357 @@ +//! Tooling to pull [`Stream`]s from one tokio runtime into another. +//! +//! This is critical so that CPU heavy loads are not run on the same runtime as IO handling +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use datafusion::error::DataFusionError; +use executor::DedicatedExecutor; +use futures::{future::BoxFuture, ready, FutureExt, Stream, StreamExt}; +use tokio::sync::mpsc::{channel, Sender}; +use tokio_stream::wrappers::ReceiverStream; + +/// [`Stream`] that is calculated by one tokio runtime but can safely be pulled from another w/o stalling (esp. when the +/// calculating runtime is CPU-blocked). +pub struct CrossRtStream { + /// Future that drives the underlying stream. 
+ /// + /// This is actually wrapped into [`DedicatedExecutor::spawn`] so it can be safely polled by the receiving runtime. + driver: BoxFuture<'static, ()>, + + /// Flags if the [driver](Self::driver) returned [`Poll::Ready`]. + driver_ready: bool, + + /// Receiving stream. + /// + /// This one can be polled from the receiving runtime. + inner: ReceiverStream, + + /// Signals that [`inner`](Self::inner) finished. + /// + /// Note that we must also drive the [driver](Self::driver) even when the stream finished to allow proper state clean-ups. + inner_done: bool, +} + +impl std::fmt::Debug for CrossRtStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CrossRtStream") + .field("driver", &"...") + .field("driver_ready", &self.driver_ready) + .field("inner", &"...") + .field("inner_done", &self.inner_done) + .finish() + } +} + +impl CrossRtStream { + /// Create new stream by producing a future that sends its state to the given [`Sender`]. + /// + /// This is an internal method. `f` should always be wrapped into [`DedicatedExecutor::spawn`] (except for testing purposes). + fn new_with_tx(f: F) -> Self + where + F: FnOnce(Sender) -> Fut, + Fut: Future + Send + 'static, + { + let (tx, rx) = channel(1); + let driver = f(tx).boxed(); + Self { + driver, + driver_ready: false, + inner: ReceiverStream::new(rx), + inner_done: false, + } + } +} + +impl CrossRtStream> +where + X: Send + 'static, + E: Send + 'static, +{ + /// Create new stream based on an existing stream that transports [`Result`]s. + /// + /// Also receives an executor that actually executes the underlying stream as well as a converter that convets + /// [`executor::JobError`] to the error type of the stream (so we can send potential crashes/panics). 
+ fn new_with_error_stream(stream: S, exec: DedicatedExecutor, converter: C) -> Self + where + S: Stream> + Send + 'static, + C: Fn(executor::JobError) -> E + Send + 'static, + { + Self::new_with_tx(|tx| { + // future to be run in the other runtime + let tx_captured = tx.clone(); + let fut = async move { + tokio::pin!(stream); + + while let Some(res) = stream.next().await { + if tx_captured.send(res).await.is_err() { + // receiver gone + return; + } + } + }; + + // future for this runtime (likely the tokio/tonic/web driver) + async move { + if let Err(e) = exec.spawn(fut).await { + let e = converter(e); + + // last message, so we don't care about the receiver side + tx.send(Err(e)).await.ok(); + } + } + }) + } +} + +impl CrossRtStream> +where + X: Send + 'static, +{ + /// Create new stream based on an existing stream that transports [`Result`]s w/ [`DataFusionError`]s. + /// + /// Also receives an executor that actually executes the underlying stream. + pub fn new_with_df_error_stream(stream: S, exec: DedicatedExecutor) -> Self + where + S: Stream> + Send + 'static, + { + Self::new_with_error_stream(stream, exec, |e| { + DataFusionError::Context( + "Join Error (panic)".to_string(), + Box::new(DataFusionError::External(e.into())), + ) + }) + } +} + +impl Stream for CrossRtStream { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = &mut *self; + + if !this.driver_ready { + let res = this.driver.poll_unpin(cx); + + if res.is_ready() { + this.driver_ready = true; + } + } + + if this.inner_done { + if this.driver_ready { + Poll::Ready(None) + } else { + Poll::Pending + } + } else { + match ready!(this.inner.poll_next_unpin(cx)) { + None => { + this.inner_done = true; + if this.driver_ready { + Poll::Ready(None) + } else { + Poll::Pending + } + } + Some(x) => Poll::Ready(Some(x)), + } + } + } +} + +#[cfg(test)] +mod tests { + use std::{sync::Arc, time::Duration}; + + use super::*; + use tokio::runtime::{Handle, 
RuntimeFlavor}; + + #[tokio::test] + async fn test_async_block() { + let exec = DedicatedExecutor::new_testing(); + let barrier1 = Arc::new(tokio::sync::Barrier::new(2)); + let barrier1_captured = Arc::clone(&barrier1); + let barrier2 = Arc::new(tokio::sync::Barrier::new(2)); + let barrier2_captured = Arc::clone(&barrier2); + let mut stream = CrossRtStream::>::new_with_error_stream( + futures::stream::once(async move { + barrier1_captured.wait().await; + barrier2_captured.wait().await; + Ok(1) + }), + exec, + std::convert::identity, + ); + + let mut f = stream.next(); + + ensure_pending(&mut f).await; + barrier1.wait().await; + ensure_pending(&mut f).await; + barrier2.wait().await; + + let res = f.await.expect("streamed data"); + assert_eq!(res.unwrap(), 1); + } + + #[tokio::test] + async fn test_sync_block() { + // This would deadlock if the stream payload would run within the same tokio runtime. To prevent any cheating + // (e.g. via channels), we ensure that the current runtime only has a single thread: + assert_eq!( + RuntimeFlavor::CurrentThread, + Handle::current().runtime_flavor() + ); + + let exec = DedicatedExecutor::new_testing(); + let barrier1 = Arc::new(std::sync::Barrier::new(2)); + let barrier1_captured = Arc::clone(&barrier1); + let barrier2 = Arc::new(std::sync::Barrier::new(2)); + let barrier2_captured = Arc::clone(&barrier2); + let mut stream = CrossRtStream::>::new_with_error_stream( + futures::stream::once(async move { + barrier1_captured.wait(); + barrier2_captured.wait(); + Ok(1) + }), + exec, + std::convert::identity, + ); + + let mut f = stream.next(); + + ensure_pending(&mut f).await; + barrier1.wait(); + ensure_pending(&mut f).await; + barrier2.wait(); + + let res = f.await.expect("streamed data"); + assert_eq!(res.unwrap(), 1); + } + + #[tokio::test] + async fn test_panic() { + let exec = DedicatedExecutor::new_testing(); + let mut stream = CrossRtStream::>::new_with_error_stream( + futures::stream::once(async { panic!("foo") }), + exec, 
+ std::convert::identity, + ); + + let e = stream + .next() + .await + .expect("stream not finished") + .unwrap_err(); + assert_eq!(e.to_string(), "Panic: foo"); + + let none = stream.next().await; + assert!(none.is_none()); + } + + #[tokio::test] + async fn test_cancel_future() { + let exec = DedicatedExecutor::new_testing(); + let barrier1 = Arc::new(tokio::sync::Barrier::new(2)); + let barrier1_captured = Arc::clone(&barrier1); + let barrier2 = Arc::new(tokio::sync::Barrier::new(2)); + let barrier2_captured = Arc::clone(&barrier2); + let mut stream = CrossRtStream::>::new_with_error_stream( + futures::stream::once(async move { + barrier1_captured.wait().await; + barrier2_captured.wait().await; + Ok(1) + }), + exec, + std::convert::identity, + ); + + let mut f = stream.next(); + + // fire up stream + ensure_pending(&mut f).await; + barrier1.wait().await; + + // cancel + drop(f); + + barrier2.wait().await; + let res = stream.next().await.expect("streamed data"); + assert_eq!(res.unwrap(), 1); + } + + #[tokio::test] + async fn test_cancel_stream() { + let exec = DedicatedExecutor::new_testing(); + let barrier = Arc::new(tokio::sync::Barrier::new(2)); + let barrier_captured = Arc::clone(&barrier); + let mut stream = CrossRtStream::>::new_with_error_stream( + futures::stream::once(async move { + barrier_captured.wait().await; + + // block forever + futures::future::pending::<()>().await; + + // keep barrier Arc alive + drop(barrier_captured); + unreachable!() + }), + exec, + std::convert::identity, + ); + + let mut f = stream.next(); + + // fire up stream + ensure_pending(&mut f).await; + barrier.wait().await; + assert_eq!(Arc::strong_count(&barrier), 2); + + // cancel + drop(f); + drop(stream); + + tokio::time::timeout(Duration::from_secs(5), async { + loop { + if Arc::strong_count(&barrier) == 1 { + return; + } + + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .unwrap(); + } + + #[tokio::test] + async fn 
test_inner_future_driven_to_completion_after_stream_ready() { + let barrier = Arc::new(tokio::sync::Barrier::new(2)); + let barrier_captured = Arc::clone(&barrier); + + let mut stream = CrossRtStream::::new_with_tx(|tx| async move { + tx.send(1).await.ok(); + drop(tx); + barrier_captured.wait().await; + }); + + let handle = tokio::spawn(async move { barrier.wait().await }); + + assert_eq!(stream.next().await, Some(1)); + handle.await.unwrap(); + } + + async fn ensure_pending(f: &mut F) + where + F: Future + Send + Unpin, + { + tokio::select! { + _ = tokio::time::sleep(Duration::from_millis(100)) => {} + _ = f => {panic!("not pending")}, + } + } +} diff --git a/iox_query/src/exec/field.rs b/iox_query/src/exec/field.rs new file mode 100644 index 0000000..5838890 --- /dev/null +++ b/iox_query/src/exec/field.rs @@ -0,0 +1,182 @@ +use std::sync::Arc; + +use arrow::{self, datatypes::SchemaRef}; +use schema::TIME_COLUMN_NAME; +use snafu::{ResultExt, Snafu}; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("Error finding field column: {:?} in schema '{}'", column_name, source))] + ColumnNotFoundForField { + column_name: String, + source: arrow::error::ArrowError, + }, +} + +pub type Result = std::result::Result; + +/// Names for a field: a value field and the associated timestamp columns +#[derive(Debug, PartialEq, Eq)] +pub enum FieldColumns { + /// All field columns share a timestamp column, named TIME_COLUMN_NAME + SharedTimestamp(Vec>), + + /// Each field has a potentially different timestamp column + // (value_name, timestamp_name) + DifferentTimestamp(Vec<(Arc, Arc)>), +} + +impl From>> for FieldColumns { + fn from(v: Vec>) -> Self { + Self::SharedTimestamp(v) + } +} + +impl From, Arc)>> for FieldColumns { + fn from(v: Vec<(Arc, Arc)>) -> Self { + Self::DifferentTimestamp(v) + } +} + +impl From> for FieldColumns { + fn from(v: Vec<&str>) -> Self { + let v = v.into_iter().map(Arc::from).collect(); + + Self::SharedTimestamp(v) + } +} + +impl From<&[&str]> 
for FieldColumns { + fn from(v: &[&str]) -> Self { + let v = v.iter().map(|v| Arc::from(*v)).collect(); + + Self::SharedTimestamp(v) + } +} + +/// Column indexes for a field: a value and corresponding timestamp +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct FieldIndex { + pub value_index: usize, + pub timestamp_index: usize, +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct FieldIndexes { + inner: Arc>, +} + +impl FieldIndexes { + /// Create FieldIndexes where each field has the same timestamp + /// and different value index + pub fn from_timestamp_and_value_indexes( + timestamp_index: usize, + value_indexes: &[usize], + ) -> Self { + value_indexes + .iter() + .map(|&value_index| FieldIndex { + value_index, + timestamp_index, + }) + .collect::>() + .into() + } + + /// Convert a slice of pairs (value_index, time_index) into + /// FieldIndexes + pub fn from_slice(v: &[(usize, usize)]) -> Self { + let inner = v + .iter() + .map(|&(value_index, timestamp_index)| FieldIndex { + value_index, + timestamp_index, + }) + .collect(); + + Self { + inner: Arc::new(inner), + } + } + + pub fn as_slice(&self) -> &[FieldIndex] { + self.inner.as_ref() + } + + pub fn iter(&self) -> impl Iterator { + self.as_slice().iter() + } +} + +impl From> for FieldIndexes { + fn from(list: Vec) -> Self { + Self { + inner: Arc::new(list), + } + } +} + +impl FieldIndexes { + // look up which column index correponds to each column name + pub fn names_to_indexes(schema: &SchemaRef, column_names: &[Arc]) -> Result> { + column_names + .iter() + .map(|column_name| { + schema + .index_of(column_name) + .context(ColumnNotFoundForFieldSnafu { + column_name: column_name.as_ref(), + }) + }) + .collect() + } + + /// Translate the field columns into pairs of (field_index, timestamp_index) + pub fn from_field_columns(schema: &SchemaRef, field_columns: &FieldColumns) -> Result { + let indexes = match field_columns { + FieldColumns::SharedTimestamp(field_names) => { + let timestamp_index = + 
schema + .index_of(TIME_COLUMN_NAME) + .context(ColumnNotFoundForFieldSnafu { + column_name: TIME_COLUMN_NAME, + })?; + + Self::names_to_indexes(schema, field_names)? + .into_iter() + .map(|field_index| FieldIndex { + value_index: field_index, + timestamp_index, + }) + .collect::>() + .into() + } + FieldColumns::DifferentTimestamp(fields_and_timestamp_names) => { + fields_and_timestamp_names + .iter() + .map(|(field_name, timestamp_name)| { + let field_index = + schema + .index_of(field_name) + .context(ColumnNotFoundForFieldSnafu { + column_name: field_name.as_ref(), + })?; + + let timestamp_index = schema.index_of(timestamp_name).context( + ColumnNotFoundForFieldSnafu { + column_name: TIME_COLUMN_NAME, + }, + )?; + + Ok(FieldIndex { + value_index: field_index, + timestamp_index, + }) + }) + .collect::>>()? + .into() + } + }; + Ok(indexes) + } +} diff --git a/iox_query/src/exec/fieldlist.rs b/iox_query/src/exec/fieldlist.rs new file mode 100644 index 0000000..e749543 --- /dev/null +++ b/iox_query/src/exec/fieldlist.rs @@ -0,0 +1,433 @@ +//! This module contains the definition of a "FieldList" a set of +//! records of (field_name, field_type, last_timestamp) and code to +//! pull them from RecordBatches +use std::{collections::BTreeMap, sync::Arc}; + +use arrow::{ + self, + array::TimestampNanosecondArray, + datatypes::{DataType, SchemaRef}, + record_batch::RecordBatch, +}; +use schema::TIME_COLUMN_NAME; + +use snafu::{ensure, ResultExt, Snafu}; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display( + "Internal error converting to FieldList. No time column in schema: {:?}. 
{}", + schema, + source + ))] + InternalNoTimeColumn { + schema: SchemaRef, + source: arrow::error::ArrowError, + }, + + #[snafu(display( + "Inconsistent data type for field '{}': found both '{:?}' and '{:?}'", + field_name, + data_type1, + data_type2 + ))] + InconsistentFieldType { + field_name: String, + data_type1: DataType, + data_type2: DataType, + }, +} + +pub type Result = std::result::Result; + +/// Represents a single Field (column)'s metadata: Name, data_type, +/// and most recent (last) timestamp. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Field { + pub name: String, + pub data_type: DataType, + pub last_timestamp: i64, +} + +/// A list of `Fields` +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct FieldList { + pub fields: Vec, +} + +/// Trait to convert RecordBatch'y things into `FieldLists`. Assumes +/// that the input RecordBatch es can each have a single string +/// column. +pub trait IntoFieldList { + /// Convert this thing into a fieldlist + fn into_fieldlist(self) -> Result; +} + +/// Converts record batches into FieldLists +impl IntoFieldList for Vec { + fn into_fieldlist(self) -> Result { + if self.is_empty() { + return Ok(FieldList::default()); + } + + // For each field in the schema (except time) for all rows + // that are non-null, update the current most-recent timestamp + // seen + let arrow_schema = self[0].schema(); + + let time_column_index = arrow_schema.index_of(TIME_COLUMN_NAME).with_context(|_| { + InternalNoTimeColumnSnafu { + schema: Arc::clone(&arrow_schema), + } + })?; + + // key: fieldname, value: highest value of time column we have seen + let mut field_times = BTreeMap::new(); + + for batch in self { + let time_column = batch + .column(time_column_index) + .as_any() + .downcast_ref::() + .expect("Downcasting time to TimestampNanosecondArray"); + + for (column_index, arrow_field) in arrow_schema.fields().iter().enumerate() { + if column_index == time_column_index { + continue; + } + let array = 
batch.column(column_index); + + // walk each value in array, looking for non-null values + let mut max_ts: Option = None; + for i in 0..batch.num_rows() { + if !array.is_null(i) { + let cur_ts = time_column.value(i); + max_ts = max_ts.map(|ts| std::cmp::max(ts, cur_ts)).or(Some(cur_ts)); + } + } + + if let Some(max_ts) = max_ts { + if let Some(ts) = field_times.get_mut(arrow_field.name()) { + *ts = std::cmp::max(max_ts, *ts); + } else { + field_times.insert(arrow_field.name().to_string(), max_ts); + } + } + } + } + + let fields = arrow_schema + .fields() + .iter() + .filter_map(|arrow_field| { + let field_name = arrow_field.name(); + if field_name == TIME_COLUMN_NAME { + None + } else { + field_times.get(field_name).map(|ts| Field { + name: field_name.to_string(), + data_type: arrow_field.data_type().clone(), + last_timestamp: *ts, + }) + } + }) + .collect(); + + Ok(FieldList { fields }) + } +} + +/// Merge several FieldLists into a single field list, merging the +/// entries appropriately +// Clippy gets confused and tells me that I should be using Self +// instead of Vec even though the type of Vec being created is different +#[allow(clippy::use_self)] +impl IntoFieldList for Vec { + fn into_fieldlist(self) -> Result { + if self.is_empty() { + return Ok(FieldList::default()); + } + + // otherwise merge the fields together + let mut field_map = BTreeMap::::new(); + + // iterate over all fields + let field_iter = self.into_iter().flat_map(|f| f.fields.into_iter()); + + for new_field in field_iter { + if let Some(existing_field) = field_map.get_mut(&new_field.name) { + ensure!( + existing_field.data_type == new_field.data_type, + InconsistentFieldTypeSnafu { + field_name: new_field.name, + data_type1: existing_field.data_type.clone(), + data_type2: new_field.data_type, + } + ); + existing_field.last_timestamp = + std::cmp::max(existing_field.last_timestamp, new_field.last_timestamp); + } + // no entry for field yet + else { + field_map.insert(new_field.name.clone(), 
new_field); + } + } + + let mut fields = field_map.into_values().collect::>(); + fields.sort_by(|a, b| a.name.cmp(&b.name)); + + Ok(FieldList { fields }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + use arrow::array::ArrayRef; + use arrow::{ + array::{Int64Array, StringArray}, + datatypes::{DataType as ArrowDataType, Field as ArrowField, Schema}, + }; + use schema::{TIME_DATA_TIMEZONE, TIME_DATA_TYPE}; + + #[test] + fn test_convert_single_batch() { + let schema = Arc::new(Schema::new(vec![ + ArrowField::new("string_field", ArrowDataType::Utf8, true), + ArrowField::new("time", TIME_DATA_TYPE(), true), + ])); + + let string_array: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "foo"])); + let timestamp_array: ArrayRef = Arc::new( + TimestampNanosecondArray::from_iter_values(vec![1000, 2000, 3000, 4000]) + .with_timezone_opt(TIME_DATA_TIMEZONE()), + ); + + let actual = do_conversion( + Arc::clone(&schema), + vec![vec![string_array, timestamp_array]], + ) + .expect("convert correctly"); + + let expected = FieldList { + fields: vec![Field { + name: "string_field".into(), + data_type: ArrowDataType::Utf8, + last_timestamp: 4000, + }], + }; + + assert_eq!( + expected, actual, + "Expected:\n{expected:#?}\nActual:\n{actual:#?}" + ); + + // expect same even if the timestamp order is different + + let string_array: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "foo"])); + let timestamp_array: ArrayRef = Arc::new( + TimestampNanosecondArray::from_iter_values(vec![1000, 4000, 2000, 3000]) + .with_timezone_opt(TIME_DATA_TIMEZONE()), + ); + + let actual = do_conversion(schema, vec![vec![string_array, timestamp_array]]) + .expect("convert correctly"); + + assert_eq!( + expected, actual, + "Expected:\n{expected:#?}\nActual:\n{actual:#?}" + ); + } + + #[test] + fn test_convert_two_batches() { + let schema = Arc::new(Schema::new(vec![ + ArrowField::new("string_field", ArrowDataType::Utf8, true), + 
ArrowField::new("time", TIME_DATA_TYPE(), true), + ])); + + let string_array1: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"])); + let timestamp_array1: ArrayRef = Arc::new( + TimestampNanosecondArray::from_iter_values(vec![1000, 3000]) + .with_timezone_opt(TIME_DATA_TIMEZONE()), + ); + + let string_array2: ArrayRef = Arc::new(StringArray::from(vec!["foo", "foo"])); + let timestamp_array2: ArrayRef = Arc::new( + TimestampNanosecondArray::from_iter_values(vec![1000, 4000]) + .with_timezone_opt(TIME_DATA_TIMEZONE()), + ); + + let actual = do_conversion( + schema, + vec![ + vec![string_array1, timestamp_array1], + vec![string_array2, timestamp_array2], + ], + ) + .expect("convert correctly"); + + let expected = FieldList { + fields: vec![Field { + name: "string_field".into(), + data_type: ArrowDataType::Utf8, + last_timestamp: 4000, + }], + }; + + assert_eq!( + expected, actual, + "Expected:\n{expected:#?}\nActual:\n{actual:#?}" + ); + } + + #[test] + fn test_convert_all_nulls() { + let schema = Arc::new(Schema::new(vec![ + ArrowField::new("string_field", ArrowDataType::Utf8, true), + ArrowField::new("time", TIME_DATA_TYPE(), true), + ])); + + // string array has no actual values, so should not be returned as a field + let string_array: ArrayRef = + Arc::new(StringArray::from(vec![None::<&str>, None, None, None])); + let timestamp_array: ArrayRef = Arc::new( + TimestampNanosecondArray::from_iter_values(vec![1000, 2000, 3000, 4000]) + .with_timezone_opt(TIME_DATA_TIMEZONE()), + ); + + let actual = do_conversion(schema, vec![vec![string_array, timestamp_array]]) + .expect("convert correctly"); + + let expected = FieldList { fields: vec![] }; + + assert_eq!( + expected, actual, + "Expected:\n{expected:#?}\nActual:\n{actual:#?}" + ); + } + + // test three columns, with different data types and null + #[test] + fn test_multi_column_multi_datatype() { + let schema = Arc::new(Schema::new(vec![ + ArrowField::new("string_field", ArrowDataType::Utf8, true), + 
ArrowField::new("int_field", ArrowDataType::Int64, true), + ArrowField::new("time", TIME_DATA_TYPE(), true), + ])); + + let string_array: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "foo"])); + let int_array: ArrayRef = + Arc::new(Int64Array::from(vec![Some(10), Some(20), Some(30), None])); + let timestamp_array: ArrayRef = Arc::new( + TimestampNanosecondArray::from_iter_values(vec![1000, 2000, 3000, 4000]) + .with_timezone_opt(TIME_DATA_TIMEZONE()), + ); + + let expected = FieldList { + fields: vec![ + Field { + name: "string_field".into(), + data_type: ArrowDataType::Utf8, + last_timestamp: 4000, + }, + Field { + name: "int_field".into(), + data_type: ArrowDataType::Int64, + last_timestamp: 3000, + }, + ], + }; + + let actual = do_conversion(schema, vec![vec![string_array, int_array, timestamp_array]]) + .expect("conversion successful"); + + assert_eq!( + expected, actual, + "Expected:\n{expected:#?}\nActual:\n{actual:#?}" + ); + } + + fn do_conversion(schema: SchemaRef, value_arrays: Vec>) -> Result { + let batches = value_arrays + .into_iter() + .map(|arrays| { + RecordBatch::try_new(Arc::clone(&schema), arrays).expect("created new record batch") + }) + .collect::>(); + + batches.into_fieldlist() + } + + #[test] + fn test_merge_field_list() { + let field1 = Field { + name: "one".into(), + data_type: ArrowDataType::Utf8, + last_timestamp: 4000, + }; + let field2 = Field { + name: "two".into(), + data_type: ArrowDataType::Int64, + last_timestamp: 3000, + }; + + let l1 = FieldList { + fields: vec![field1, field2.clone()], + }; + let actual = vec![l1.clone()].into_fieldlist().unwrap(); + let expected = l1.clone(); + + assert_eq!( + expected, actual, + "Expected:\n{expected:#?}\nActual:\n{actual:#?}" + ); + + let field1_later = Field { + name: "one".into(), + data_type: ArrowDataType::Utf8, + last_timestamp: 5000, + }; + + // use something that has a later timestamp and expect the later one takes + // precedence + let l2 = FieldList { + fields: 
vec![field1_later.clone()], + }; + let actual = vec![l1.clone(), l2.clone()].into_fieldlist().unwrap(); + let expected = FieldList { + fields: vec![field1_later, field2], + }; + + assert_eq!( + expected, actual, + "Expected:\n{expected:#?}\nActual:\n{actual:#?}" + ); + + // Now, try to add a field that has a different type + + let field1_new_type = Field { + name: "one".into(), + data_type: ArrowDataType::Int64, + last_timestamp: 5000, + }; + + // use something that has a later timestamp and expect the later one takes + // precedence + let l3 = FieldList { + fields: vec![field1_new_type], + }; + let actual = vec![l1, l2, l3].into_fieldlist(); + let actual_error = actual.expect_err("should be an error").to_string(); + + let expected_error = + "Inconsistent data type for field 'one': found both 'Utf8' and 'Int64'"; + + assert!( + actual_error.contains(expected_error), + "Can not find expected '{expected_error}' in actual '{actual_error}'" + ); + } +} diff --git a/iox_query/src/exec/gapfill/algo.rs b/iox_query/src/exec/gapfill/algo.rs new file mode 100644 index 0000000..0733038 --- /dev/null +++ b/iox_query/src/exec/gapfill/algo.rs @@ -0,0 +1,1650 @@ +//! Contains the [GapFiller] type which does the +//! actual gap filling of record batches. + +mod interpolate; + +use std::{ops::Range, sync::Arc}; + +use arrow::{ + array::{Array, ArrayRef, TimestampNanosecondArray, UInt64Array}, + compute::{kernels::take, partition}, + datatypes::SchemaRef, + record_batch::RecordBatch, +}; +use datafusion::{ + error::{DataFusionError, Result}, + scalar::ScalarValue, +}; +use hashbrown::HashMap; + +use self::interpolate::Segment; + +use super::{params::GapFillParams, FillStrategy}; + +/// Provides methods to the [`GapFillStream`](super::stream::GapFillStream) +/// module that fill gaps in buffered input. 
+/// +/// [GapFiller] assumes that there will be at least `output_batch_size + 2` +/// input records buffered when [`build_gapfilled_output`](GapFiller::build_gapfilled_output) +/// is invoked, provided there is enough data. +/// +/// Once output is produced, clients should call `slice_input_batch` to unbuffer +/// data that is no longer needed. +/// +/// Below is a diagram of how buffered input is structured. +/// +/// ```text +/// +/// BUFFERED INPUT ROWS +/// +/// time group columns aggregate columns +/// ╓────╥───┬───┬─────────────╥───┬───┬─────────────╖ +/// context row 0 ║ ║ │ │ . . . ║ │ │ . . . ║ +/// ╟────╫───┼───┼─────────────╫───┼───┼─────────────╢ +/// ┬──── cursor────► 1 ║ ║ │ │ ║ │ │ ║ +/// │ ╟────╫───┼───┼─────────────╫───┼───┼─────────────╢ +/// │ 2 ║ ║ │ │ ║ │ │ ║ +/// │ ╟────╫───┼───┼─────────────╫───┼───┼─────────────╢ +/// │ . . . +/// output_batch_size . . . +/// │ . . . +/// │ ╟────╫───┼───┼─────────────╫───┼───┼─────────────╢ +/// │ n - 1 ║ ║ │ │ ║ │ │ ║ +/// │ ╟────╫───┼───┼─────────────╫───┼───┼─────────────╢ +/// ┴──── n ║ ║ │ │ ║ │ │ ║ +/// ╟────╫───┼───┼─────────────╫───┼───┼─────────────╢ +/// trailing row(s) n + 1 ║ ║ │ │ ║ │ │ ║ +/// ╟────╫───┼───┼─────────────╫───┼───┼─────────────╢ +/// . . . +/// . . . +/// . . . +/// ``` +/// +/// Just before generating output, the cursor will generally point at offset 1 +/// in the input, since offset 0 is a _context row_. The exception to this is +/// there is no context row when generating the first output batch. +/// +/// Buffering at least `output_batch_size + 2` rows ensures that: +/// - `GapFiller` can produce enough rows to produce a complete output batch, since +/// every input row will appear in the output. +/// - There is a _context row_ that represents the last input row that got output before +/// the current output batch. 
Group column values will be taken from this row +/// (using the [`take`](take::take) kernel) when we are generating trailing gaps, i.e., +/// when all of the input rows have been output for a series in the previous batch, +/// but there still remains missing rows to produce at the end. +/// - Having at least one additional _trailing row_ at the end ensures that `GapFiller` can +/// infer whether there is trailing gaps to produce at the beginning of the +/// next batch, since it can discover if the last row starts a new series. +/// - If there are columns that have a fill strategy of [`LinearInterpolate`], then more +/// trailing rows may be necessary to find the next non-null value for the column. +/// +/// [`LinearInterpolate`]: FillStrategy::LinearInterpolate +#[derive(Debug)] +pub(super) struct GapFiller { + /// The static parameters of gap-filling: time range start, end and the stride. + params: GapFillParams, + /// The number of rows to produce in each output batch. + batch_size: usize, + /// The current state of gap-filling, including the next timestamp, + /// the offset of the next input row, and remaining space in output batch. + cursor: Cursor, +} + +impl GapFiller { + /// Initialize a [GapFiller] at the beginning of an input record batch. + pub fn new(params: GapFillParams, batch_size: usize) -> Self { + let cursor = Cursor::new(¶ms); + Self { + params, + batch_size, + cursor, + } + } + + /// Given that the cursor points at the input row that will be + /// the first row in the next output batch, return the offset + /// of last input row that could possibly be in the output. + /// + /// This offset is used by ['BufferedInput`] to determine how many + /// rows need to be buffered. + /// + /// [`BufferedInput`]: super::BufferedInput + pub(super) fn last_output_row_offset(&self) -> usize { + self.cursor.next_input_offset + self.batch_size - 1 + } + + /// Returns true if there are no more output rows to produce given + /// the number of rows of buffered input. 
+ pub fn done(&self, buffered_input_row_count: usize) -> bool { + self.cursor.done(buffered_input_row_count) + } + + /// Produces a gap-filled output [RecordBatch]. + /// + /// Input arrays are represented as pairs that include their offset in the + /// schema at member `0`. + pub fn build_gapfilled_output( + &mut self, + schema: SchemaRef, + input_time_array: (usize, &TimestampNanosecondArray), + group_arrays: &[(usize, ArrayRef)], + aggr_arrays: &[(usize, ArrayRef)], + ) -> Result { + let series_ends = self.plan_output_batch(input_time_array.1, group_arrays)?; + self.cursor.remaining_output_batch_size = self.batch_size; + self.build_output( + schema, + input_time_array, + group_arrays, + aggr_arrays, + &series_ends, + ) + } + + /// Slice the input batch so that it has one context row before the next input offset. + pub fn slice_input_batch(&mut self, batch: RecordBatch) -> Result { + if self.cursor.next_input_offset < 2 { + // nothing to do + return Ok(batch); + } + + let offset = self.cursor.next_input_offset - 1; + self.cursor.slice(offset, &batch)?; + + let len = batch.num_rows() - offset; + Ok(batch.slice(offset, len)) + } + + /// Produces a vector of offsets that are the exclusive ends of each series + /// in the buffered input. It will return the ends of only those series + /// that can at least be started in the output batch. + /// + /// Uses [`lexicographical_partition_ranges`](arrow::compute::lexicographical_partition_ranges) + /// to partition input rows into series. + fn plan_output_batch( + &mut self, + input_time_array: &TimestampNanosecondArray, + group_arr: &[(usize, ArrayRef)], + ) -> Result> { + if group_arr.is_empty() { + // there are no group columns, so the output + // will be just one big series. 
+ return Ok(vec![input_time_array.len()]); + } + + let sort_columns = group_arr + .iter() + .map(|(_, arr)| Arc::clone(arr)) + .collect::>(); + + let mut ranges = partition(&sort_columns)?.ranges().into_iter(); + + let mut series_ends = vec![]; + let mut cursor = self.cursor.clone_for_aggr_col(None)?; + let mut output_row_count = 0; + + let start_offset = cursor.next_input_offset; + assert!(start_offset <= 1, "input is sliced after it is consumed"); + while output_row_count < self.batch_size { + match ranges.next() { + Some(Range { end, .. }) => { + assert!( + end > 0, + "each lexicographical partition will have at least one row" + ); + + if let Some(nrows) = + cursor.count_series_rows(&self.params, input_time_array, end) + { + output_row_count += nrows; + series_ends.push(end); + } + } + None => break, + } + } + + Ok(series_ends) + } + + /// Helper method that produces gap-filled record batches. + /// + /// This method works by producing each array in the output completely, + /// for all series that have end offsets in `series_ends`, before producing + /// subsequent arrays. 
+ fn build_output( + &mut self, + schema: SchemaRef, + input_time_array: (usize, &TimestampNanosecondArray), + group_arr: &[(usize, ArrayRef)], + aggr_arr: &[(usize, ArrayRef)], + series_ends: &[usize], + ) -> Result { + let mut output_arrays: Vec<(usize, ArrayRef)> = + Vec::with_capacity(group_arr.len() + aggr_arr.len() + 1); // plus one for time column + + // build the time column + let mut cursor = self.cursor.clone_for_aggr_col(None)?; + let (time_idx, input_time_array) = input_time_array; + let time_vec = cursor.build_time_vec(&self.params, series_ends, input_time_array)?; + let output_time_len = time_vec.len(); + output_arrays.push(( + time_idx, + Arc::new( + TimestampNanosecondArray::from(time_vec) + .with_timezone_opt(input_time_array.timezone()), + ), + )); + // There may not be any aggregate or group columns, so use this cursor state as the new + // GapFiller cursor once this output batch is complete. + let mut final_cursor = cursor; + + // build the other group columns + for (idx, ga) in group_arr { + let mut cursor = self.cursor.clone_for_aggr_col(None)?; + let take_vec = + cursor.build_group_take_vec(&self.params, series_ends, input_time_array)?; + if take_vec.len() != output_time_len { + return Err(DataFusionError::Internal(format!( + "gapfill group column has {} rows, expected {}", + take_vec.len(), + output_time_len + ))); + } + let take_arr = UInt64Array::from(take_vec); + output_arrays.push((*idx, take::take(ga, &take_arr, None)?)) + } + + // Build the aggregate columns + for (idx, aa) in aggr_arr { + let mut cursor = self.cursor.clone_for_aggr_col(Some(*idx))?; + let output_array = + cursor.build_aggr_col(&self.params, series_ends, input_time_array, aa)?; + if output_array.len() != output_time_len { + return Err(DataFusionError::Internal(format!( + "gapfill aggr column has {} rows, expected {}", + output_array.len(), + output_time_len + ))); + } + output_arrays.push((*idx, output_array)); + final_cursor.merge_aggr_col_cursor(cursor); + } + + 
output_arrays.sort_by(|(a, _), (b, _)| a.cmp(b)); + let output_arrays: Vec<_> = output_arrays.into_iter().map(|(_, arr)| arr).collect(); + let batch = RecordBatch::try_new(Arc::clone(&schema), output_arrays) + .map_err(|err| DataFusionError::ArrowError(err, None))?; + + self.cursor = final_cursor; + Ok(batch) + } +} + +/// Maintains the state needed to fill gaps in output columns. Also provides methods +/// for building vectors that build time, group, and aggregate output arrays. +#[derive(Debug)] +pub(crate) struct Cursor { + /// Where to read the next row from the input. + next_input_offset: usize, + /// The next timestamp to be produced for the current series. + /// Since the lower bound for gap filling could just be "whatever + /// the first timestamp in the series is," this may be `None` before + /// any rows with non-null timestamps are produced for a series. + next_ts: Option, + /// How many rows may be output before we need to start a new record batch. + remaining_output_batch_size: usize, + /// True if there are trailing gaps from after the last input row for a series + /// to be produced at the beginning of the next output batch. + trailing_gaps: bool, + /// State for each aggregate column, keyed on the columns offset in the schema. + aggr_col_states: HashMap, +} + +impl Cursor { + /// Creates a new cursor. + fn new(params: &GapFillParams) -> Self { + let aggr_col_states = params + .fill_strategy + .iter() + .map(|(idx, fs)| (*idx, AggrColState::new(fs))) + .collect(); + Self { + next_input_offset: 0, + next_ts: params.first_ts, + remaining_output_batch_size: 0, + trailing_gaps: false, + aggr_col_states, + } + } + + /// Returns true of we point past all rows of buffered input and there + /// are no trailing gaps left to produce. 
+ fn done(&self, buffered_input_row_count: usize) -> bool { + self.next_input_offset == buffered_input_row_count && !self.trailing_gaps + } + + /// Make a clone of this cursor to be used for creating an aggregate column, + /// if `idx` is `Some`. The resulting `Cursor` will only contain [AggrColState] + /// for the indicated column. + /// + /// When `idx` is `None`, return a `Cursor` with an empty [Cursor::aggr_col_states]. + fn clone_for_aggr_col(&self, idx: Option) -> Result { + let mut cur = Self { + next_input_offset: self.next_input_offset, + next_ts: self.next_ts, + remaining_output_batch_size: self.remaining_output_batch_size, + trailing_gaps: self.trailing_gaps, + aggr_col_states: HashMap::default(), + }; + if let Some(idx) = idx { + let state = self + .aggr_col_states + .get(&idx) + .ok_or(DataFusionError::Internal(format!( + "could not find aggr col with offset {idx}" + )))?; + cur.aggr_col_states.insert(idx, state.clone()); + } + Ok(cur) + } + + /// Update [Cursor::aggr_col_states] with updated state for an + /// aggregate column. `cursor` will have been created via `Cursor::clone_for_aggr_col`, + /// so [Cursor::aggr_col_states] will contain exactly one item. + /// + /// # Panics + /// + /// Will panic if input cursor's [Cursor::aggr_col_states] does not contain exactly one item. + fn merge_aggr_col_cursor(&mut self, cursor: Self) { + assert_eq!(1, cursor.aggr_col_states.len()); + for (idx, state) in cursor.aggr_col_states.into_iter() { + self.aggr_col_states.insert(idx, state); + } + } + + /// Get the [AggrColState] for this cursor. `self` will have been created via + /// `Cursor::clone_for_aggr_col`, so [Cursor::aggr_col_states] will contain exactly one item. + /// + /// # Panics + /// + /// Will panic if [Cursor::aggr_col_states] does not contain exactly one item. 
+ fn get_aggr_col_state(&self) -> &AggrColState { + assert_eq!(1, self.aggr_col_states.len()); + self.aggr_col_states.iter().next().unwrap().1 + } + + /// Set the [AggrColState] for this cursor. `self` will have been created via + /// `Cursor::clone_for_aggr_col`, so [Cursor::aggr_col_states] will contain exactly one item. + /// + /// # Panics + /// + /// Will panic if [Cursor::aggr_col_states] does not contain exactly one item. + fn set_aggr_col_state(&mut self, new_state: AggrColState) { + assert_eq!(1, self.aggr_col_states.len()); + let (_idx, state) = self.aggr_col_states.iter_mut().next().unwrap(); + *state = new_state; + } + + /// Counts the number of rows that will be produced for a series that ends (exclusively) + /// at `series_end`, including rows that have a null timestamp, if any. + /// + /// Produces `None` for the case where `next_input_offset` is equal to `series_end`, + /// and there are no trailing gaps to produce. + fn count_series_rows( + &mut self, + params: &GapFillParams, + input_time_array: &TimestampNanosecondArray, + series_end: usize, + ) -> Option { + if !self.trailing_gaps && self.next_input_offset == series_end { + return None; + } + + let mut count = if input_time_array.null_count() > 0 { + let len = series_end - self.next_input_offset; + let slice = input_time_array.slice(self.next_input_offset, len); + slice.null_count() + } else { + 0 + }; + + self.next_input_offset += count; + if self.maybe_init_next_ts(input_time_array, series_end) { + count += params.valid_row_count(self.next_ts.unwrap()); + } + + self.next_input_offset = series_end; + self.next_ts = params.first_ts; + + Some(count) + } + + /// Update this cursor to reflect that `offset` older rows are being sliced off from the + /// buffered input. 
+ fn slice(&mut self, offset: usize, batch: &RecordBatch) -> Result<()> { + for (idx, aggr_col_state) in &mut self.aggr_col_states { + aggr_col_state.slice(offset, batch.column(*idx))?; + } + self.next_input_offset -= offset; + Ok(()) + } + + /// Attempts to assign a value to `self.next_ts` if it does not have one. + /// + /// This bit of abstraction is needed because the lower bound for gap filling may be + /// determined in one of two ways: + /// * If the [`GapFillParams`] provided by client code has `first_ts` set to `Some`, this + /// will be the first timestamp for each series. In this case `self.next_ts` + /// will never `None`, and this function does nothing. + /// * Otherwise it is determined to be whatever the first timestamp in the input series is. + /// In this case `params.first_ts == None`, and we need to extract the timestamp from + /// the input time array. + /// + /// Returns true if `self.next_ts` ends up containing a value. + fn maybe_init_next_ts( + &mut self, + input_time_array: &TimestampNanosecondArray, + series_end: usize, + ) -> bool { + self.next_ts = match self.next_ts { + Some(_) => self.next_ts, + None if self.next_input_offset < series_end + && input_time_array.is_valid(self.next_input_offset) => + { + Some(input_time_array.value(self.next_input_offset)) + } + // This may happen if current input offset points at a row + // with a null timestamp, or is past the end of the current series. + _ => None, + }; + self.next_ts.is_some() + } + + /// Builds a vector that can be used to produce a timestamp array. + fn build_time_vec( + &mut self, + params: &GapFillParams, + series_ends: &[usize], + input_time_array: &TimestampNanosecondArray, + ) -> Result>> { + struct TimeBuilder { + times: Vec>, + } + + impl VecBuilder for TimeBuilder { + fn push(&mut self, row_status: RowStatus) -> Result<()> { + match row_status { + RowStatus::NullTimestamp { .. } => self.times.push(None), + RowStatus::Present { ts, .. } | RowStatus::Missing { ts, .. 
} => { + self.times.push(Some(ts)) + } + } + Ok(()) + } + } + + let mut time_builder = TimeBuilder { + times: Vec::with_capacity(self.remaining_output_batch_size), + }; + self.build_vec(params, input_time_array, series_ends, &mut time_builder)?; + + Ok(time_builder.times) + } + + /// Builds a vector that can use the [`take`](take::take) kernel + /// to produce a group column. + fn build_group_take_vec( + &mut self, + params: &GapFillParams, + series_ends: &[usize], + input_time_array: &TimestampNanosecondArray, + ) -> Result> { + struct GroupBuilder { + take_idxs: Vec, + } + + impl VecBuilder for GroupBuilder { + fn push(&mut self, row_status: RowStatus) -> Result<()> { + match row_status { + RowStatus::NullTimestamp { + series_end_offset, .. + } + | RowStatus::Present { + series_end_offset, .. + } + | RowStatus::Missing { + series_end_offset, .. + } => self.take_idxs.push(series_end_offset as u64 - 1), + } + Ok(()) + } + } + + let mut group_builder = GroupBuilder { + take_idxs: Vec::with_capacity(self.remaining_output_batch_size), + }; + self.build_vec(params, input_time_array, series_ends, &mut group_builder)?; + + Ok(group_builder.take_idxs) + } + + /// Produce a gap-filled array for the aggregate column + /// in [`Self::aggr_col_states`]. + /// + /// # Panics + /// + /// Will panic if [Cursor::aggr_col_states] does not contain exactly one item. + fn build_aggr_col( + &mut self, + params: &GapFillParams, + series_ends: &[usize], + input_time_array: &TimestampNanosecondArray, + input_aggr_array: &ArrayRef, + ) -> Result { + match self.get_aggr_col_state() { + AggrColState::Null => { + self.build_aggr_fill_null(params, series_ends, input_time_array, input_aggr_array) + } + AggrColState::PrevNullAsIntentional { .. } | AggrColState::PrevNullAsMissing { .. } => { + self.build_aggr_fill_prev(params, series_ends, input_time_array, input_aggr_array) + } + AggrColState::PrevNullAsMissingStashed { .. 
} => self.build_aggr_fill_prev_stashed( + params, + series_ends, + input_time_array, + input_aggr_array, + ), + AggrColState::LinearInterpolate(_) => self.build_aggr_fill_interpolate( + params, + series_ends, + input_time_array, + input_aggr_array, + ), + } + } + + /// Builds an array using the [`take`](take::take) kernel + /// to produce an aggregate output column, filling gaps with + /// null values. + fn build_aggr_fill_null( + &mut self, + params: &GapFillParams, + series_ends: &[usize], + input_time_array: &TimestampNanosecondArray, + input_aggr_array: &ArrayRef, + ) -> Result { + struct AggrBuilder { + take_idxs: Vec>, + } + + impl VecBuilder for AggrBuilder { + fn push(&mut self, row_status: RowStatus) -> Result<()> { + match row_status { + RowStatus::NullTimestamp { offset, .. } | RowStatus::Present { offset, .. } => { + self.take_idxs.push(Some(offset as u64)) + } + RowStatus::Missing { .. } => self.take_idxs.push(None), + } + Ok(()) + } + } + + let mut aggr_builder = AggrBuilder { + take_idxs: Vec::with_capacity(self.remaining_output_batch_size), + }; + self.build_vec(params, input_time_array, series_ends, &mut aggr_builder)?; + + let take_arr = UInt64Array::from(aggr_builder.take_idxs); + take::take(input_aggr_array, &take_arr, None) + .map_err(|err| DataFusionError::ArrowError(err, None)) + } + + /// Builds an array using the [`take`](take::take) kernel + /// to produce an aggregate output column, filling gaps with the + /// previous values in the column. + fn build_aggr_fill_prev( + &mut self, + params: &GapFillParams, + series_ends: &[usize], + input_time_array: &TimestampNanosecondArray, + input_aggr_array: &ArrayRef, + ) -> Result { + struct AggrBuilder<'a> { + take_idxs: Vec>, + prev_offset: Option, + input_aggr_array: &'a ArrayRef, + null_as_missing: bool, + } + + impl<'a> VecBuilder for AggrBuilder<'a> { + fn push(&mut self, row_status: RowStatus) -> Result<()> { + match row_status { + RowStatus::NullTimestamp { offset, .. 
} => { + self.take_idxs.push(Some(offset as u64)) + } + RowStatus::Present { offset, .. } => { + if !self.null_as_missing || self.input_aggr_array.is_valid(offset) { + self.take_idxs.push(Some(offset as u64)); + self.prev_offset = Some(offset as u64); + } else { + self.take_idxs.push(self.prev_offset); + } + } + RowStatus::Missing { .. } => self.take_idxs.push(self.prev_offset), + } + Ok(()) + } + fn start_new_series(&mut self) -> Result<()> { + self.prev_offset = None; + Ok(()) + } + } + + let null_as_missing = matches!( + self.get_aggr_col_state(), + AggrColState::PrevNullAsMissing { .. } + ); + + let mut aggr_builder = AggrBuilder { + take_idxs: Vec::with_capacity(self.remaining_output_batch_size), + prev_offset: self.get_aggr_col_state().prev_offset(), + input_aggr_array, + null_as_missing, + }; + self.build_vec(params, input_time_array, series_ends, &mut aggr_builder)?; + + let AggrBuilder { + take_idxs, + prev_offset, + .. + } = aggr_builder; + self.set_aggr_col_state(match null_as_missing { + false => AggrColState::PrevNullAsIntentional { + offset: prev_offset, + }, + true => AggrColState::PrevNullAsMissing { + offset: prev_offset, + }, + }); + + let take_arr = UInt64Array::from(take_idxs); + take::take(input_aggr_array, &take_arr, None) + .map_err(|err| DataFusionError::ArrowError(err, None)) + } + + /// Builds an array using the [`interleave`](arrow::compute::interleave) kernel + /// to produce an aggregate output column, filling gaps with the + /// previous values in the column. 
+ fn build_aggr_fill_prev_stashed( + &mut self, + params: &GapFillParams, + series_ends: &[usize], + input_time_array: &TimestampNanosecondArray, + input_aggr_array: &ArrayRef, + ) -> Result { + let stash = self.get_aggr_col_state().stash(); + let mut aggr_builder = StashedAggrBuilder { + interleave_idxs: Vec::with_capacity(self.remaining_output_batch_size), + state: StashedAggrState::Stashed, + stash, + input_aggr_array, + }; + self.build_vec(params, input_time_array, series_ends, &mut aggr_builder)?; + let output_array = aggr_builder.build()?; + + // Update the aggregate column state for this cursor to prime it for the + // next batch. + let StashedAggrBuilder { state, .. } = aggr_builder; + match state { + StashedAggrState::Stashed => (), // nothing changes + StashedAggrState::PrevNone => { + self.set_aggr_col_state(AggrColState::PrevNullAsMissing { offset: None }) + } + StashedAggrState::PrevSome { offset } => { + self.set_aggr_col_state(AggrColState::PrevNullAsMissing { + offset: Some(offset as u64), + }) + } + }; + + Ok(output_array) + } + + /// Helper method that iterates over each series + /// that ends with offsets in `series_ends` and produces + /// the appropriate output values. 
+ fn build_vec( + &mut self, + params: &GapFillParams, + input_time_array: &TimestampNanosecondArray, + series_ends: &[usize], + vec_builder: &mut impl VecBuilder, + ) -> Result<()> { + for series in series_ends { + if self + .next_ts + .map_or(false, |next_ts| next_ts > params.last_ts) + { + vec_builder.start_new_series()?; + self.next_ts = params.first_ts; + } + + self.append_series_items(params, input_time_array, *series, vec_builder)?; + } + + let last_series_end = series_ends.last().ok_or(DataFusionError::Internal( + "expected at least one item in series batch".to_string(), + ))?; + + self.trailing_gaps = self.next_input_offset == *last_series_end + && self + .next_ts + .map_or(true, |next_ts| next_ts <= params.last_ts); + Ok(()) + } + + /// Helper method that generates output for one series by invoking + /// [VecBuilder::push] for each output value in the column to be generated. + fn append_series_items( + &mut self, + params: &GapFillParams, + input_times: &TimestampNanosecondArray, + series_end: usize, + vec_builder: &mut impl VecBuilder, + ) -> Result<()> { + // If there are any null timestamps for this group, they will be first. + // These rows can just be copied into the output. + // Append the corresponding values. 
+ while self.remaining_output_batch_size > 0 + && self.next_input_offset < series_end + && input_times.is_null(self.next_input_offset) + { + vec_builder.push(RowStatus::NullTimestamp { + series_end_offset: series_end, + offset: self.next_input_offset, + })?; + self.remaining_output_batch_size -= 1; + self.next_input_offset += 1; + } + + if !self.maybe_init_next_ts(input_times, series_end) { + return Ok(()); + } + let mut next_ts = self.next_ts.unwrap(); + + let output_row_count = std::cmp::min( + params.valid_row_count(next_ts), + self.remaining_output_batch_size, + ); + if output_row_count == 0 { + return Ok(()); + } + + // last_ts is the last timestamp that will fit in the output batch + let last_ts = next_ts + (output_row_count - 1) as i64 * params.stride; + + loop { + if self.next_input_offset >= series_end { + break; + } + let in_ts = input_times.value(self.next_input_offset); + if in_ts > last_ts { + break; + } + while next_ts < in_ts { + vec_builder.push(RowStatus::Missing { + series_end_offset: series_end, + ts: next_ts, + })?; + next_ts += params.stride; + } + vec_builder.push(RowStatus::Present { + series_end_offset: series_end, + offset: self.next_input_offset, + ts: next_ts, + })?; + next_ts += params.stride; + self.next_input_offset += 1; + } + + // Add any additional missing values after the last of the input. + while next_ts <= last_ts { + vec_builder.push(RowStatus::Missing { + series_end_offset: series_end, + ts: next_ts, + })?; + next_ts += params.stride; + } + + self.next_ts = Some(last_ts + params.stride); + self.remaining_output_batch_size -= output_row_count; + Ok(()) + } +} + +/// Maintains the state needed to fill gaps in an aggregate column, +/// depending on the fill strategy. +#[derive(Clone, Debug)] +enum AggrColState { + /// For [FillStrategy::Null] there is no state to maintain. + Null, + /// For [FillStrategy::PrevNullAsIntentional]. + PrevNullAsIntentional { offset: Option }, + /// For [FillStrategy::PrevNullAsMissing]. 
+ PrevNullAsMissing { offset: Option }, + /// For [FillStrategy::PrevNullAsMissing], when + /// the fill value must be stashed in a separate array so it + /// can persist across output batches. + /// + /// This state happens when the previous value in the buffered input + /// rows has gone away during a call to [`GapFiller::slice_input_batch`]. + PrevNullAsMissingStashed { stash: ArrayRef }, + /// For [FillStrategy::LinearInterpolate], this tracks if we are in the middle + /// of a "segment" (two non-null points in the input separated by more + /// than the stride) between output batches. + LinearInterpolate(Option>), +} + +impl AggrColState { + /// Create a new [AggrColState] based on the [FillStrategy] for the column. + fn new(fill_strategy: &FillStrategy) -> Self { + match fill_strategy { + FillStrategy::Null => Self::Null, + FillStrategy::PrevNullAsIntentional => Self::PrevNullAsIntentional { offset: None }, + FillStrategy::PrevNullAsMissing => Self::PrevNullAsMissing { offset: None }, + FillStrategy::LinearInterpolate => Self::LinearInterpolate(None), + } + } + + /// Return the offset in the input from which to fill gaps. + /// + /// # Panics + /// + /// This method will panic if `self` is not [AggrColState::PrevNullAsIntentional] + /// or [AggrColState::PrevNullAsMissing]. + fn prev_offset(&self) -> Option { + match self { + Self::PrevNullAsIntentional { offset } | Self::PrevNullAsMissing { offset } => *offset, + _ => unreachable!(), + } + } + + /// Update state to reflect that older rows in the buffered input + /// are being sliced away. + fn slice(&mut self, offset: usize, array: &ArrayRef) -> Result<()> { + let offset = offset as u64; + match self { + Self::PrevNullAsMissing { offset: Some(v) } if offset > *v => { + // The element in the buffered input that may be in the output + // will be sliced away, so store it on the side. 
+ let stash = StashedAggrBuilder::create_stash(array, *v)?; + *self = Self::PrevNullAsMissingStashed { stash }; + } + Self::PrevNullAsIntentional { offset: Some(v) } + | Self::PrevNullAsMissing { offset: Some(v) } => *v -= offset, + _ => (), + }; + Ok(()) + } + + /// Return the stashed previous value used to fill gaps. + /// + /// # Panics + /// + /// This method will panic if `self` is not [AggrColState::PrevNullAsMissingStashed]. + fn stash(&self) -> ArrayRef { + match self { + Self::PrevNullAsMissingStashed { stash } => Arc::clone(stash), + _ => unreachable!(), + } + } + + /// Return the segment being interpolated, if any. + /// + /// # Panics + /// + /// This method will panic if `self` is not [AggrColState::LinearInterpolate]. + fn segment(&self) -> &Option> { + match self { + Self::LinearInterpolate(segment) => segment, + _ => unreachable!(), + } + } +} + +/// A trait that lets implementors describe how to build the +/// vectors used to create Arrow arrays in the output. +trait VecBuilder { + /// Pushes a new value based on the output row's + /// relation to the input row. + fn push(&mut self, _: RowStatus) -> Result<()>; + + /// Called just before a new series starts. + fn start_new_series(&mut self) -> Result<()> { + Ok(()) + } +} + +/// The state of an input row relative to gap-filled output. +#[derive(Debug)] +enum RowStatus { + /// This row had a null timestamp in the input. + NullTimestamp { + /// The exclusive offset of the series end in the input. + series_end_offset: usize, + /// The offset of the null timestamp in the input time array. + offset: usize, + }, + /// A row with this timestamp is present in the input. + Present { + /// The exclusive offset of the series end in the input. + series_end_offset: usize, + /// The offset of the value in the input time array. + offset: usize, + /// The timestamp corresponding to this row. + ts: i64, + }, + /// A row with this timestamp is missing from the input. 
+ Missing { + /// The exclusive offset of the series end in the input. + series_end_offset: usize, + /// The timestamp corresponding to this row. + ts: i64, + }, +} + +/// Implements [`VecBuilder`] for [`FillStrategy::PrevNullAsMissing`], +/// specifically for the case where a previous value that needs to be +/// propagated into a new output batch has been sliced off from +/// buffered input rows. +struct StashedAggrBuilder<'a> { + interleave_idxs: Vec<(usize, usize)>, + state: StashedAggrState, + stash: ArrayRef, + input_aggr_array: &'a ArrayRef, +} + +impl StashedAggrBuilder<'_> { + /// Create a 2-element array containing a null value and the value from + /// `input_aggr_array` at `offset` for use with the [`interleave`](arrow::compute::interleave) + /// kernel. + fn create_stash(input_aggr_array: &ArrayRef, offset: u64) -> Result { + let take_arr: UInt64Array = vec![None, Some(offset)].into(); + let stash = take::take(input_aggr_array, &take_arr, None) + .map_err(|err| DataFusionError::ArrowError(err, None))?; + Ok(stash) + } + + /// Build the output column. + fn build(&self) -> Result { + arrow::compute::interleave(&[&self.stash, self.input_aggr_array], &self.interleave_idxs) + .map_err(|err| DataFusionError::ArrowError(err, None)) + } + + fn buffered_input(offset: usize) -> (usize, usize) { + (Self::BUFFERED_INPUT_ARRAY, offset) + } + + const STASHED_NULL: (usize, usize) = (0, 0); + const STASHED_VALUE: (usize, usize) = (0, 1); + const BUFFERED_INPUT_ARRAY: usize = 1; +} + +/// Stores state about how to fill the output aggregate column +/// for [`StashedAggrBuilder`]. +enum StashedAggrState { + /// Fill the next missing or null element with the + /// stashed value. + Stashed, + /// Fill the next missing or null element with a null value. + PrevNone, + /// Fill the next missing or null element with the element in the + /// input at `offset`. 
+ PrevSome { offset: usize }, +} + +impl<'a> VecBuilder for StashedAggrBuilder<'a> { + fn push(&mut self, row_status: RowStatus) -> Result<()> { + match row_status { + RowStatus::NullTimestamp { offset, .. } => { + self.interleave_idxs.push(Self::buffered_input(offset)); + self.state = StashedAggrState::PrevNone; + } + RowStatus::Present { offset, .. } if self.input_aggr_array.is_valid(offset) => { + self.interleave_idxs.push(Self::buffered_input(offset)); + self.state = StashedAggrState::PrevSome { offset }; + } + RowStatus::Present { .. } | RowStatus::Missing { .. } => match self.state { + StashedAggrState::Stashed => self.interleave_idxs.push(Self::STASHED_VALUE), + StashedAggrState::PrevNone => self.interleave_idxs.push(Self::STASHED_NULL), + StashedAggrState::PrevSome { offset } => { + self.interleave_idxs.push(Self::buffered_input(offset)) + } + }, + } + + Ok(()) + } + + fn start_new_series(&mut self) -> Result<()> { + self.state = StashedAggrState::PrevNone; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::{ + array::{ArrayRef, Float64Array, TimestampNanosecondArray}, + datatypes::{Field, Schema}, + record_batch::RecordBatch, + }; + use arrow_util::test_util::batches_to_lines; + use datafusion::error::Result; + use hashbrown::HashMap; + use schema::{InfluxColumnType, TIME_DATA_TIMEZONE}; + + use crate::exec::gapfill::{ + algo::{AggrColState, Cursor}, + params::GapFillParams, + FillStrategy, + }; + + #[test] + fn test_cursor_append_time_values() -> Result<()> { + test_helpers::maybe_start_logging(); + let input_times = TimestampNanosecondArray::from(vec![1000, 1100, 1200]); + let series = input_times.len(); + + let params = GapFillParams { + stride: 50, + first_ts: Some(950), + last_ts: 1250, + fill_strategy: simple_fill_strategy(), + }; + + let output_batch_size = 10000; + let mut cursor = new_cursor_with_batch_size(¶ms, output_batch_size); + + let out_times = cursor.build_time_vec(¶ms, &[series], &input_times)?; + 
assert_eq!( + vec![ + Some(950), + Some(1000), + Some(1050), + Some(1100), + Some(1150), + Some(1200), + Some(1250) + ], + out_times + ); + + assert_cursor_end_state(&cursor, &input_times, ¶ms); + + Ok(()) + } + + #[test] + fn test_cursor_append_time_values_no_first_ts() { + test_helpers::maybe_start_logging(); + let input_times = TimestampNanosecondArray::from(vec![1100, 1200]); + let series = input_times.len(); + + let params = GapFillParams { + stride: 50, + first_ts: None, + last_ts: 1250, + fill_strategy: simple_fill_strategy(), + }; + + let output_batch_size = 10000; + let mut cursor = new_cursor_with_batch_size(¶ms, output_batch_size); + + let out_times = cursor + .build_time_vec(¶ms, &[series], &input_times) + .unwrap(); + assert_eq!( + vec![Some(1100), Some(1150), Some(1200), Some(1250)], + out_times + ); + + assert_cursor_end_state(&cursor, &input_times, ¶ms); + } + + #[test] + fn test_cursor_append_time_value_nulls() -> Result<()> { + test_helpers::maybe_start_logging(); + let input_times = + TimestampNanosecondArray::from(vec![None, None, Some(1000), Some(1100), Some(1200)]); + let series = input_times.len(); + + let params = GapFillParams { + stride: 50, + first_ts: Some(950), + last_ts: 1250, + fill_strategy: simple_fill_strategy(), + }; + + let output_batch_size = 10000; + let mut cursor = new_cursor_with_batch_size(¶ms, output_batch_size); + let out_times = cursor.build_time_vec(¶ms, &[series], &input_times)?; + assert_eq!( + vec![ + None, + None, + Some(950), + Some(1000), + Some(1050), + Some(1100), + Some(1150), + Some(1200), + Some(1250) + ], + out_times + ); + + assert_cursor_end_state(&cursor, &input_times, ¶ms); + + Ok(()) + } + + #[test] + fn test_cursor_append_group_take() -> Result<()> { + let input_times = TimestampNanosecondArray::from(vec![1000, 1100, 1200]); + let series = input_times.len(); + + let params = GapFillParams { + stride: 50, + first_ts: Some(950), + last_ts: 1250, + fill_strategy: simple_fill_strategy(), + }; + + let 
output_batch_size = 10000; + let mut cursor = new_cursor_with_batch_size(¶ms, output_batch_size); + let take_idxs = cursor.build_group_take_vec(¶ms, &[series], &input_times)?; + assert_eq!(vec![2; 7], take_idxs); + + assert_cursor_end_state(&cursor, &input_times, ¶ms); + + Ok(()) + } + + #[test] + fn test_cursor_append_aggr_take() { + let input_times = TimestampNanosecondArray::from(vec![1000, 1100, 1200]); + let input_aggr_array: ArrayRef = Arc::new(Float64Array::from(vec![10.0, 11.0, 12.0])); + let series = input_times.len(); + + let params = GapFillParams { + stride: 50, + first_ts: Some(950), + last_ts: 1250, + fill_strategy: simple_fill_strategy(), + }; + + let output_batch_size = 10000; + let mut cursor = new_cursor_with_batch_size(¶ms, output_batch_size); + + let time_arr = TimestampNanosecondArray::from( + cursor + .clone_for_aggr_col(None) + .unwrap() + .build_time_vec(¶ms, &[series], &input_times) + .unwrap(), + ) + .with_timezone_opt(TIME_DATA_TIMEZONE()); + let arr = cursor + .build_aggr_fill_null(¶ms, &[series], &input_times, &input_aggr_array) + .unwrap(); + insta::assert_yaml_snapshot!(array_to_lines(&time_arr, &arr), @r###" + --- + - +--------------------------------+------+ + - "| time | a0 |" + - +--------------------------------+------+ + - "| 1970-01-01T00:00:00.000000950Z | |" + - "| 1970-01-01T00:00:00.000001Z | 10.0 |" + - "| 1970-01-01T00:00:00.000001050Z | |" + - "| 1970-01-01T00:00:00.000001100Z | 11.0 |" + - "| 1970-01-01T00:00:00.000001150Z | |" + - "| 1970-01-01T00:00:00.000001200Z | 12.0 |" + - "| 1970-01-01T00:00:00.000001250Z | |" + - +--------------------------------+------+ + "###); + + assert_cursor_end_state(&cursor, &input_times, ¶ms); + } + + #[test] + fn test_cursor_append_aggr_take_nulls() -> Result<()> { + test_helpers::maybe_start_logging(); + let input_times = + TimestampNanosecondArray::from(vec![None, None, Some(1000), Some(1100), Some(1200)]); + let input_aggr_array: ArrayRef = + Arc::new(Float64Array::from(vec![0.1, 
0.2, 10.0, 11.0, 12.0])); + let series = input_times.len(); + + let params = GapFillParams { + stride: 50, + first_ts: Some(950), + last_ts: 1250, + fill_strategy: simple_fill_strategy(), + }; + + let output_batch_size = 10000; + let mut cursor = new_cursor_with_batch_size(¶ms, output_batch_size); + + let time_arr = TimestampNanosecondArray::from( + cursor + .clone_for_aggr_col(None) + .unwrap() + .build_time_vec(¶ms, &[series], &input_times) + .unwrap(), + ) + .with_timezone_opt(TIME_DATA_TIMEZONE()); + let arr = + cursor.build_aggr_fill_null(¶ms, &[series], &input_times, &input_aggr_array)?; + insta::assert_yaml_snapshot!(array_to_lines(&time_arr, &arr), @r###" + --- + - +--------------------------------+------+ + - "| time | a0 |" + - +--------------------------------+------+ + - "| | 0.1 |" + - "| | 0.2 |" + - "| 1970-01-01T00:00:00.000000950Z | |" + - "| 1970-01-01T00:00:00.000001Z | 10.0 |" + - "| 1970-01-01T00:00:00.000001050Z | |" + - "| 1970-01-01T00:00:00.000001100Z | 11.0 |" + - "| 1970-01-01T00:00:00.000001150Z | |" + - "| 1970-01-01T00:00:00.000001200Z | 12.0 |" + - "| 1970-01-01T00:00:00.000001250Z | |" + - +--------------------------------+------+ + "###); + + assert_cursor_end_state(&cursor, &input_times, ¶ms); + + Ok(()) + } + + #[test] + fn test_cursor_append_aggr_take_prev() { + let input_times = TimestampNanosecondArray::from(vec![ + // 950 + 1000, // 1050 + 1100, // 1150 + 1200, + // 1250 + ]); + let input_aggr_array: ArrayRef = Arc::new(Float64Array::from(vec![10.0, 11.0, 12.0])); + let series = input_times.len(); + + let idx = 0; + let params = GapFillParams { + stride: 50, + first_ts: Some(950), + last_ts: 1250, + fill_strategy: prev_fill_strategy(idx), + }; + + let output_batch_size = 10000; + let mut cursor = new_cursor_with_batch_size(¶ms, output_batch_size); + + let time_arr = TimestampNanosecondArray::from( + cursor + .clone_for_aggr_col(None) + .unwrap() + .build_time_vec(¶ms, &[series], &input_times) + .unwrap(), + ) + 
.with_timezone_opt(TIME_DATA_TIMEZONE()); + let arr = cursor + .build_aggr_fill_prev(¶ms, &[series], &input_times, &input_aggr_array) + .unwrap(); + insta::assert_yaml_snapshot!(array_to_lines(&time_arr, &arr), @r###" + --- + - +--------------------------------+------+ + - "| time | a0 |" + - +--------------------------------+------+ + - "| 1970-01-01T00:00:00.000000950Z | |" + - "| 1970-01-01T00:00:00.000001Z | 10.0 |" + - "| 1970-01-01T00:00:00.000001050Z | 10.0 |" + - "| 1970-01-01T00:00:00.000001100Z | 11.0 |" + - "| 1970-01-01T00:00:00.000001150Z | 11.0 |" + - "| 1970-01-01T00:00:00.000001200Z | 12.0 |" + - "| 1970-01-01T00:00:00.000001250Z | 12.0 |" + - +--------------------------------+------+ + "###); + + assert_cursor_end_state(&cursor, &input_times, ¶ms); + } + + #[test] + fn test_cursor_append_aggr_take_prev_with_nulls() { + let input_times = TimestampNanosecondArray::from(vec![ + None, + None, + // 950, + Some(1000), + // 1050 + Some(1100), + // 1150 + Some(1200), + // 1250 + // + ]); + let input_aggr_array: ArrayRef = + Arc::new(Float64Array::from(vec![0.0, 0.1, 10.0, 11.0, 12.0])); + let series = input_times.len(); + + let idx = 0; + let params = GapFillParams { + stride: 50, + first_ts: Some(950), + last_ts: 1250, + fill_strategy: prev_fill_strategy(idx), + }; + + let output_batch_size = 10000; + let mut cursor = new_cursor_with_batch_size(¶ms, output_batch_size); + + let time_arr = TimestampNanosecondArray::from( + cursor + .clone_for_aggr_col(None) + .unwrap() + .build_time_vec(¶ms, &[series], &input_times) + .unwrap(), + ) + .with_timezone_opt(TIME_DATA_TIMEZONE()); + let arr = cursor + .build_aggr_fill_prev(¶ms, &[series], &input_times, &input_aggr_array) + .unwrap(); + insta::assert_yaml_snapshot!(array_to_lines(&time_arr, &arr), @r###" + --- + - +--------------------------------+------+ + - "| time | a0 |" + - +--------------------------------+------+ + - "| | 0.0 |" + - "| | 0.1 |" + - "| 1970-01-01T00:00:00.000000950Z | |" + - "| 
1970-01-01T00:00:00.000001Z | 10.0 |" + - "| 1970-01-01T00:00:00.000001050Z | 10.0 |" + - "| 1970-01-01T00:00:00.000001100Z | 11.0 |" + - "| 1970-01-01T00:00:00.000001150Z | 11.0 |" + - "| 1970-01-01T00:00:00.000001200Z | 12.0 |" + - "| 1970-01-01T00:00:00.000001250Z | 12.0 |" + - +--------------------------------+------+ + "###); + + assert_cursor_end_state(&cursor, &input_times, ¶ms); + } + + #[test] + fn test_cursor_append_aggr_take_prev_multi_series() { + let input_times = TimestampNanosecondArray::from(vec![ + // 950 + // 1000 + Some(1050), + // 1100 + // --- new series + // 950 + // 1000 + Some(1050), + // 1100 + ]) + .with_timezone_opt(TIME_DATA_TIMEZONE()); + let input_aggr_array: ArrayRef = Arc::new(Float64Array::from(vec![10.0, 11.0])); + let series_ends = vec![1, 2]; + + let idx = 0; + let params = GapFillParams { + stride: 50, + first_ts: Some(950), + last_ts: 1100, + fill_strategy: prev_fill_strategy(idx), + }; + + let output_batch_size = 10000; + let mut cursor = new_cursor_with_batch_size(¶ms, output_batch_size); + + let time_arr = TimestampNanosecondArray::from( + cursor + .clone_for_aggr_col(None) + .unwrap() + .build_time_vec(¶ms, &series_ends, &input_times) + .unwrap(), + ) + .with_timezone_opt(TIME_DATA_TIMEZONE()); + let arr = cursor + .build_aggr_fill_null(¶ms, &series_ends, &input_times, &input_aggr_array) + .unwrap(); + insta::assert_yaml_snapshot!(array_to_lines(&time_arr, &arr), @r###" + --- + - +--------------------------------+------+ + - "| time | a0 |" + - +--------------------------------+------+ + - "| 1970-01-01T00:00:00.000000950Z | |" + - "| 1970-01-01T00:00:00.000001Z | |" + - "| 1970-01-01T00:00:00.000001050Z | 10.0 |" + - "| 1970-01-01T00:00:00.000001100Z | |" + - "| 1970-01-01T00:00:00.000000950Z | |" + - "| 1970-01-01T00:00:00.000001Z | |" + - "| 1970-01-01T00:00:00.000001050Z | 11.0 |" + - "| 1970-01-01T00:00:00.000001100Z | |" + - +--------------------------------+------+ + "###); + + assert_cursor_end_state(&cursor, 
&input_times, ¶ms); + } + + #[test] + fn test_cursor_aggr_prev_null_as_missing() { + let input_times = TimestampNanosecondArray::from(vec![ + // 950 + // 1000 + Some(1050), + Some(1100), + // --- new series + Some(950), + Some(1000), + Some(1050), + Some(1100), + ]) + .with_timezone_opt(TIME_DATA_TIMEZONE()); + let input_aggr_array: ArrayRef = Arc::new(Float64Array::from(vec![ + // 950 + // 1000 + Some(10.0), // 1050 + None, // 1100 + Some(20.0), // 950 + None, // 1000 + Some(21.0), // 1050 + None, // 1100 + ])); + let series_ends = vec![2, 6]; + + let idx = 0; + let params = GapFillParams { + stride: 50, + first_ts: Some(950), + last_ts: 1100, + fill_strategy: prev_null_as_missing_fill_strategy(idx), + }; + + let output_batch_size = 10000; + let mut cursor = new_cursor_with_batch_size(¶ms, output_batch_size); + + let time_arr = TimestampNanosecondArray::from( + cursor + .clone_for_aggr_col(None) + .unwrap() + .build_time_vec(¶ms, &series_ends, &input_times) + .unwrap(), + ) + .with_timezone_opt(TIME_DATA_TIMEZONE()); + let arr = cursor + .build_aggr_fill_prev(¶ms, &series_ends, &input_times, &input_aggr_array) + .unwrap(); + insta::assert_yaml_snapshot!(array_to_lines(&time_arr, &arr), @r###" + --- + - +--------------------------------+------+ + - "| time | a0 |" + - +--------------------------------+------+ + - "| 1970-01-01T00:00:00.000000950Z | |" + - "| 1970-01-01T00:00:00.000001Z | |" + - "| 1970-01-01T00:00:00.000001050Z | 10.0 |" + - "| 1970-01-01T00:00:00.000001100Z | 10.0 |" + - "| 1970-01-01T00:00:00.000000950Z | 20.0 |" + - "| 1970-01-01T00:00:00.000001Z | 20.0 |" + - "| 1970-01-01T00:00:00.000001050Z | 21.0 |" + - "| 1970-01-01T00:00:00.000001100Z | 21.0 |" + - +--------------------------------+------+ + "###); + + assert_cursor_end_state(&cursor, &input_times, ¶ms); + } + + #[test] + fn test_cursor_aggr_prev_null_as_missing_stashed() { + // This test is intended to simulate producing output with + // prev-null-as-missing when the previous element has 
been + // sliced away from the buffered input and is "stashed" in + // another array on the side. + let input_times = TimestampNanosecondArray::from(vec![ + // Some(950), // output in last batch + // ^^^^^^^^^ this element has been sliced off + // 1000 // <-- cursor.next_ts + Some(1050), // context row + Some(1100), // <-- cursor.next_input_offset + // 1150 + // --- new series + None, // null timestamp + // 950 + Some(1000), + Some(1050), + Some(1100), + Some(1100), + ]) + .with_timezone_opt(TIME_DATA_TIMEZONE()); + let input_aggr_array: ArrayRef = Arc::new(Float64Array::from(vec![ + // Some(9.0) // 950 + // ^^^^^^^^^ this element has been sliced off + // 1000 // filled with stashed because missing + None, // 1050 // filled with stashed because null + Some(10.0), // 1100 // present + // 1150 // filled with previous because missing + // -- new series + Some(-20.0), // null timestamp + // 950 // null because no value for this series yet + None, // 1000 // still null + Some(21.1), + None, // 1100 // filled with previous because null value in column + None, // 1150 // filled with previous because null value in column + ])); + let series_ends = vec![2, 7]; + + let aggr_col_idx = 0; + let params = GapFillParams { + stride: 50, + first_ts: Some(950), + last_ts: 1150, + fill_strategy: prev_null_as_missing_fill_strategy(aggr_col_idx), + }; + + let stash: Float64Array = vec![None, Some(9.0)].into(); + let stash: ArrayRef = Arc::new(stash); + let output_batch_size = 10000; + let mut cursor = Cursor { + next_input_offset: 1, + next_ts: Some(1000), + remaining_output_batch_size: output_batch_size, + trailing_gaps: false, + aggr_col_states: std::iter::once(( + aggr_col_idx, + AggrColState::PrevNullAsMissingStashed { stash }, + )) + .collect(), + }; + + let time_arr = TimestampNanosecondArray::from( + cursor + .clone_for_aggr_col(None) + .unwrap() + .build_time_vec(¶ms, &series_ends, &input_times) + .unwrap(), + ) + .with_timezone_opt(TIME_DATA_TIMEZONE()); + let arr = cursor + 
.build_aggr_fill_prev_stashed(¶ms, &series_ends, &input_times, &input_aggr_array) + .unwrap(); + insta::assert_yaml_snapshot!(array_to_lines(&time_arr, &arr), @r###" + --- + - +--------------------------------+-------+ + - "| time | a0 |" + - +--------------------------------+-------+ + - "| 1970-01-01T00:00:00.000001Z | 9.0 |" + - "| 1970-01-01T00:00:00.000001050Z | 9.0 |" + - "| 1970-01-01T00:00:00.000001100Z | 10.0 |" + - "| 1970-01-01T00:00:00.000001150Z | 10.0 |" + - "| | -20.0 |" + - "| 1970-01-01T00:00:00.000000950Z | |" + - "| 1970-01-01T00:00:00.000001Z | |" + - "| 1970-01-01T00:00:00.000001050Z | 21.1 |" + - "| 1970-01-01T00:00:00.000001100Z | 21.1 |" + - "| 1970-01-01T00:00:00.000001150Z | 21.1 |" + - +--------------------------------+-------+ + "###); + + assert_cursor_end_state(&cursor, &input_times, ¶ms); + } + + pub(crate) fn array_to_lines( + time_array: &TimestampNanosecondArray, + aggr_array: &ArrayRef, + ) -> Vec { + let data_type = aggr_array.data_type().clone(); + let schema = Schema::new(vec![ + Field::new( + "time".to_string(), + (&InfluxColumnType::Timestamp).into(), + true, + ), + Field::new("a0".to_string(), data_type, true), + ]); + + let time_array: ArrayRef = Arc::new(time_array.clone()); + let arrays = vec![time_array, Arc::clone(aggr_array)]; + let rb = RecordBatch::try_new(Arc::new(schema), arrays).unwrap(); + batches_to_lines(&[rb]) + } + + pub(crate) fn new_cursor_with_batch_size(params: &GapFillParams, batch_size: usize) -> Cursor { + let mut cursor = Cursor::new(params); + cursor.remaining_output_batch_size = batch_size; + cursor + } + + pub(crate) fn assert_cursor_end_state( + cursor: &Cursor, + input_times: &TimestampNanosecondArray, + params: &GapFillParams, + ) { + assert_eq!(input_times.len(), cursor.next_input_offset); + assert_eq!(params.last_ts + params.stride, cursor.next_ts.unwrap()); + } + + fn simple_fill_strategy() -> HashMap { + std::iter::once((1, FillStrategy::Null)).collect() + } + + fn prev_fill_strategy(idx: 
usize) -> HashMap { + std::iter::once((idx, FillStrategy::PrevNullAsIntentional)).collect() + } + + fn prev_null_as_missing_fill_strategy(idx: usize) -> HashMap { + std::iter::once((idx, FillStrategy::PrevNullAsMissing)).collect() + } +} diff --git a/iox_query/src/exec/gapfill/algo/interpolate.rs b/iox_query/src/exec/gapfill/algo/interpolate.rs new file mode 100644 index 0000000..277e01b --- /dev/null +++ b/iox_query/src/exec/gapfill/algo/interpolate.rs @@ -0,0 +1,592 @@ +//! Filling gaps with interpolated values. +use std::sync::Arc; + +use arrow::{ + array::{ + as_primitive_array, as_struct_array, Array, ArrayRef, PrimitiveArray, StructArray, + TimestampNanosecondArray, + }, + datatypes::{ArrowPrimitiveType, DataType, Float64Type, Int64Type, UInt64Type}, +}; + +use crate::exec::gapfill::params::GapFillParams; + +use datafusion::{ + error::{DataFusionError, Result}, + scalar::ScalarValue, +}; + +use super::{AggrColState, Cursor, RowStatus, VecBuilder}; + +/// [Cursor] methods that are related to interpolation. +impl Cursor { + /// Create an Arrow array with gaps filled in between values + /// using linear interpolation. 
+ pub(super) fn build_aggr_fill_interpolate( + &mut self, + params: &GapFillParams, + series_ends: &[usize], + input_time_array: &TimestampNanosecondArray, + input_aggr_array: &ArrayRef, + ) -> Result { + match input_aggr_array.data_type() { + DataType::Int64 => { + let input_aggr_array = as_primitive_array::(input_aggr_array); + self.build_aggr_fill_interpolate_typed( + params, + series_ends, + input_time_array, + input_aggr_array, + ) + } + DataType::UInt64 => { + let input_aggr_array = as_primitive_array::(input_aggr_array); + self.build_aggr_fill_interpolate_typed( + params, + series_ends, + input_time_array, + input_aggr_array, + ) + } + DataType::Float64 => { + let input_aggr_array = as_primitive_array::(input_aggr_array); + self.build_aggr_fill_interpolate_typed( + params, + series_ends, + input_time_array, + input_aggr_array, + ) + } + DataType::Struct(_) => { + // The only struct type that is expected is the one produced by the + // selector_* functions. These consist of a value, a timestamp and a + // number of associated values selected from the same row. When + // interpolating it is only the value field that will be interpolated. + // All other columns in the structure are filled with nulls. + + let input_aggr_array = as_struct_array(input_aggr_array); + let (fields, arrays, _) = input_aggr_array.clone().into_parts(); + let cursors = fields + .iter() + .map(|f| { + if f.name() == "value" { + // The "value" array uses the parent cursor. 
+                            Ok(None)
+                        } else {
+                            Ok(Some(self.clone_for_aggr_col(None)?))
+                        }
+                    })
+                    .collect::<Result<Vec<Option<Self>>>>()?;
+                let new_arrays = cursors
+                    .into_iter()
+                    .zip(arrays.into_iter())
+                    .map(|(cursor, a)| {
+                        if let Some(mut c) = cursor {
+                            c.build_aggr_fill_null(params, series_ends, input_time_array, &a)
+                        } else {
+                            self.build_aggr_fill_interpolate(
+                                params,
+                                series_ends,
+                                input_time_array,
+                                &a,
+                            )
+                        }
+                    })
+                    .collect::<Result<Vec<ArrayRef>>>()?;
+                Ok(Arc::new(StructArray::new(fields, new_arrays, None)))
+            }
+            dt => Err(DataFusionError::Execution(format!(
+                "unsupported data type {dt} for interpolation gap filling"
+            ))),
+        }
+    }
+
+    /// Create an Arrow array with gaps filled in between values
+    /// using linear interpolation.
+    ///
+    /// This method has a template parameter and so accepts Arrow arrays of either
+    /// [Int64Array], [UInt64Array], or [Float64Array].
+    ///
+    /// [Int64Array]: arrow::array::Int64Array
+    /// [UInt64Array]: arrow::array::UInt64Array
+    /// [Float64Array]: arrow::array::Float64Array
+    pub(super) fn build_aggr_fill_interpolate_typed<T>(
+        &mut self,
+        params: &GapFillParams,
+        series_ends: &[usize],
+        input_time_array: &TimestampNanosecondArray,
+        input_aggr_array: &PrimitiveArray<T>,
+    ) -> Result<ArrayRef>
+    where
+        T: ArrowPrimitiveType,
+        T::Native: LinearInterpolate,
+        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
+        Segment<T::Native>: TryFrom<Segment<ScalarValue>, Error = DataFusionError>,
+        Segment<ScalarValue>: From<Segment<T::Native>>,
+    {
+        // Restore any in-progress segment carried over from the previous
+        // output batch, converting it from ScalarValue to the native type.
+        let segment = self
+            .get_aggr_col_state()
+            .segment()
+            .as_ref()
+            .map(|seg| Segment::<T::Native>::try_from(seg.clone()))
+            .transpose()?;
+        let mut builder = InterpolateBuilder {
+            values: Vec::with_capacity(self.remaining_output_batch_size),
+            segment,
+            input_time_array,
+            input_aggr_array,
+        };
+        self.build_vec(params, input_time_array, series_ends, &mut builder)?;
+
+        // Stash the (possibly still open) segment back into the cursor state
+        // as ScalarValues so the next output batch can resume it.
+        let segment: Option<Segment<ScalarValue>> = builder.segment.clone().map(|seg| seg.into());
+        self.set_aggr_col_state(AggrColState::LinearInterpolate(segment));
+        let array: PrimitiveArray<T> = builder.values.into();
+        Ok(Arc::new(array))
+    }
+}
+
+/// Represents two non-null data values at two points in time, where the
+/// gap between them must be filled. The template parameter `T` stands in for
+/// the type of the input aggregate column being filled.
+#[derive(Clone, Debug)]
+pub struct Segment<T> {
+    start_point: (i64, T),
+    end_point: (i64, T),
+}
+
+/// A macro to go from `Segment<$NATIVE>` into [`Segment<ScalarValue>`].
+/// Between output batches data values in segments are stored as [`ScalarValue`]
+/// to avoid type parameters in [`Cursor`].
+macro_rules! impl_try_from_segment_native {
+    ($NATIVE:ident) => {
+        impl TryFrom<Segment<ScalarValue>> for Segment<$NATIVE> {
+            type Error = DataFusionError;
+
+            fn try_from(segment: Segment<ScalarValue>) -> Result<Self> {
+                let Segment {
+                    start_point: (start_ts, start_sv),
+                    end_point: (end_ts, end_sv),
+                } = segment;
+
+                let start_v = $NATIVE::try_from(start_sv)?;
+                let end_v = $NATIVE::try_from(end_sv)?;
+                Ok(Segment {
+                    start_point: (start_ts, start_v),
+                    end_point: (end_ts, end_v),
+                })
+            }
+        }
+    };
+}
+
+impl_try_from_segment_native!(i64);
+impl_try_from_segment_native!(u64);
+impl_try_from_segment_native!(f64);
+
+/// A macro to go from [`Segment<ScalarValue>`] into `Segment<$NATIVE>`.
+/// When producing an output batch, it's easiest to use the native type
+/// to represent segments being filled.
+macro_rules! impl_from_segment_scalar_value {
+    ($NATIVE:ident) => {
+        impl From<Segment<$NATIVE>> for Segment<ScalarValue> {
+            fn from(segment: Segment<$NATIVE>) -> Self {
+                let Segment {
+                    start_point: (start_ts, start_native),
+                    end_point: (end_ts, end_native),
+                } = segment;
+
+                let start_v = ScalarValue::from(start_native);
+                let end_v = ScalarValue::from(end_native);
+                Segment {
+                    start_point: (start_ts, start_v),
+                    end_point: (end_ts, end_v),
+                }
+            }
+        }
+    };
+}
+
+impl_from_segment_scalar_value!(i64);
+impl_from_segment_scalar_value!(u64);
+impl_from_segment_scalar_value!(f64);
+
+/// Implements [`VecBuilder`] for building aggregate columns whose gaps
+/// are being filled using linear interpolation.
+pub(super) struct InterpolateBuilder<'a, T: ArrowPrimitiveType> {
+    pub values: Vec<Option<T::Native>>,
+    pub segment: Option<Segment<T::Native>>,
+    pub input_time_array: &'a TimestampNanosecondArray,
+    pub input_aggr_array: &'a PrimitiveArray<T>,
+}
+
+impl<'a, T> VecBuilder for InterpolateBuilder<'a, T>
+where
+    T: ArrowPrimitiveType,
+    T::Native: LinearInterpolate,
+{
+    fn push(&mut self, row_status: RowStatus) -> Result<()> {
+        match row_status {
+            RowStatus::NullTimestamp { offset, .. } => self.copy_point(offset),
+            RowStatus::Present {
+                ts,
+                offset,
+                series_end_offset,
+            } => {
+                if self.input_aggr_array.is_valid(offset) {
+                    let end_offset = self.find_end_offset(offset, series_end_offset);
+                    // Find the next non-null value in this column for the series.
+                    // If there is one, start a new segment at the current value.
+                    self.segment = end_offset.map(|end_offset| Segment {
+                        start_point: (ts, self.input_aggr_array.value(offset)),
+                        end_point: (
+                            self.input_time_array.value(end_offset),
+                            self.input_aggr_array.value(end_offset),
+                        ),
+                    });
+                    self.copy_point(offset);
+                } else {
+                    self.values.push(
+                        self.segment
+                            .as_ref()
+                            .map(|seg| T::Native::interpolate(seg, ts)),
+                    );
+                }
+            }
+            RowStatus::Missing { ts, .. } => self.values.push(
+                self.segment
+                    .as_ref()
+                    .map(|seg| T::Native::interpolate(seg, ts)),
+            ),
+        }
+        Ok(())
+    }
+
+    fn start_new_series(&mut self) -> Result<()> {
+        self.segment = None;
+        Ok(())
+    }
+}
+
+impl<T> InterpolateBuilder<'_, T>
+where
+    T: ArrowPrimitiveType,
+{
+    /// Copies a point at `offset` into the vector that will be used to build
+    /// an Arrow array.
+    fn copy_point(&mut self, offset: usize) {
+        let v = self
+            .input_aggr_array
+            .is_valid(offset)
+            .then_some(self.input_aggr_array.value(offset));
+        self.values.push(v)
+    }
+
+    /// Scan forward to find the endpoint for a segment that starts at `start_offset`.
+    /// Skip over any null values.
+    ///
+    /// We are guaranteed to have buffered enough input to find the next non-null point for this series,
+    /// if there is one, by the logic in [`BufferedInput`].
+    ///
+    /// [`BufferedInput`]: super::super::buffered_input::BufferedInput
+    fn find_end_offset(&self, start_offset: usize, series_end_offset: usize) -> Option<usize> {
+        ((start_offset + 1)..series_end_offset).find(|&i| self.input_aggr_array.is_valid(i))
+    }
+}
+
+/// A trait for the native numeric types that can be interpolated
+/// by IOx.
+///
+/// All implementations match what the
+/// [1.8 Go implementation](https://github.com/influxdata/influxdb/blob/master/query/linear.go)
+/// of InfluxQL does.
+pub(super) trait LinearInterpolate
+where
+    Self: Sized,
+{
+    /// Given a [`Segment`] compute the value of the column at timestamp `ts`.
+    fn interpolate(segment: &Segment<Self>, ts: i64) -> Self;
+}
+
+impl LinearInterpolate for i64 {
+    fn interpolate(segment: &Segment<Self>, ts: i64) -> Self {
+        // Compute in f64 then truncate toward zero, matching InfluxQL.
+        let rise = (segment.end_point.1 - segment.start_point.1) as f64;
+        let run = (segment.end_point.0 - segment.start_point.0) as f64;
+        let m = rise / run;
+        let x = (ts - segment.start_point.0) as f64;
+        let b: f64 = segment.start_point.1 as f64;
+        (m * x + b) as Self
+    }
+}
+
+impl LinearInterpolate for u64 {
+    fn interpolate(segment: &Segment<Self>, ts: i64) -> Self {
+        // A descending segment has a negative slope; use abs_diff so the
+        // unsigned subtraction cannot overflow.
+        let rise = if segment.end_point.1 >= segment.start_point.1 {
+            (segment.end_point.1 - segment.start_point.1) as f64
+        } else {
+            -(segment.end_point.1.abs_diff(segment.start_point.1) as f64)
+        };
+        let run = (segment.end_point.0 - segment.start_point.0) as f64;
+        let m = rise / run;
+        let x = (ts - segment.start_point.0) as f64;
+        let b: f64 = segment.start_point.1 as f64;
+        (m * x + b) as Self
+    }
+}
+
+impl LinearInterpolate for f64 {
+    fn interpolate(segment: &Segment<Self>, ts: i64) -> Self {
+        let rise = segment.end_point.1 - segment.start_point.1;
+        let run = (segment.end_point.0 - segment.start_point.0) as Self;
+        let m = rise / run;
+        let x = (ts - segment.start_point.0) as Self;
+        let b = segment.start_point.1;
+        m * x + b
+    }
+} + +/// These tests verify that interpolation works as expected for each data type. +/// For comprehensive tests that handle multiple series and input/output +/// batches, see [crate::exec::gapfill::exec_tests]. +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{ArrayRef, Float64Array, Int64Array, TimestampNanosecondArray, UInt64Array}; + use hashbrown::HashMap; + use schema::TIME_DATA_TIMEZONE; + + use crate::exec::gapfill::{ + algo::tests::{array_to_lines, assert_cursor_end_state, new_cursor_with_batch_size}, + params::GapFillParams, + FillStrategy, + }; + + /// Verify the rounding behavior (really just truncating towards zero) which is + /// what InfluxQL does. Also verify that we can have a descending slope in the + /// line that does not overflow a `u64`. + #[test] + fn test_interpolate_u64() { + let input_times = TimestampNanosecondArray::from(vec![ + // 1000 + Some(1100), + // 1200 + // 1300 + Some(1400), + Some(1500), + // 1600 + Some(1700), + // 1800 + Some(1900), + // 2000 + ]); + let input_aggr_array: ArrayRef = Arc::new(UInt64Array::from(vec![ + // 1000 + Some(100), // 1100 + // 1200 + // 1300 + Some(200), // 1400 + None, // 1500 + // 1600 + Some(1000), // 1700 + // 1800 + Some(0), // 1900 + // 2000 + ])); + let series_ends = vec![input_times.len()]; + + let idx = 0; + let params = GapFillParams { + stride: 100, + first_ts: Some(1000), + last_ts: 2000, + fill_strategy: interpolate_fill_strategy(idx), + }; + + let output_batch_size = 10000; + let mut cursor = new_cursor_with_batch_size(¶ms, output_batch_size); + + let time_arr = TimestampNanosecondArray::from( + cursor + .clone_for_aggr_col(None) + .unwrap() + .build_time_vec(¶ms, &series_ends, &input_times) + .unwrap(), + ) + .with_timezone_opt(TIME_DATA_TIMEZONE()); + let arr = cursor + .build_aggr_fill_interpolate(¶ms, &series_ends, &input_times, &input_aggr_array) + .unwrap(); + insta::assert_yaml_snapshot!(array_to_lines(&time_arr, &arr), @r###" + --- + - 
+--------------------------------+------+ + - "| time | a0 |" + - +--------------------------------+------+ + - "| 1970-01-01T00:00:00.000001Z | |" + - "| 1970-01-01T00:00:00.000001100Z | 100 |" + - "| 1970-01-01T00:00:00.000001200Z | 133 |" + - "| 1970-01-01T00:00:00.000001300Z | 166 |" + - "| 1970-01-01T00:00:00.000001400Z | 200 |" + - "| 1970-01-01T00:00:00.000001500Z | 466 |" + - "| 1970-01-01T00:00:00.000001600Z | 733 |" + - "| 1970-01-01T00:00:00.000001700Z | 1000 |" + - "| 1970-01-01T00:00:00.000001800Z | 500 |" + - "| 1970-01-01T00:00:00.000001900Z | 0 |" + - "| 1970-01-01T00:00:00.000002Z | |" + - +--------------------------------+------+ + "###); + + assert_cursor_end_state(&cursor, &input_times, ¶ms); + } + + #[test] + fn test_interpolate_i64() { + let input_times = TimestampNanosecondArray::from(vec![ + // 1000 + Some(1100), + // 1200 + // 1300 + Some(1400), + Some(1500), + // 1600 + Some(1700), + // 1800 + Some(1900), + // 2000 + ]); + let input_aggr_array: ArrayRef = Arc::new(Int64Array::from(vec![ + // 1000 + Some(100), // 1100 + // 1200 + // 1300 + Some(200), // 1400 + None, // 1500 + // 1600 + Some(1000), // 1700 + // 1800 + Some(0), // 1900 + // 2000 + ])); + let series_ends = vec![input_times.len()]; + + let idx = 0; + let params = GapFillParams { + stride: 100, + first_ts: Some(1000), + last_ts: 2000, + fill_strategy: interpolate_fill_strategy(idx), + }; + + let output_batch_size = 10000; + let mut cursor = new_cursor_with_batch_size(¶ms, output_batch_size); + + let time_arr = TimestampNanosecondArray::from( + cursor + .clone_for_aggr_col(None) + .unwrap() + .build_time_vec(¶ms, &series_ends, &input_times) + .unwrap(), + ) + .with_timezone_opt(TIME_DATA_TIMEZONE()); + let arr = cursor + .build_aggr_fill_interpolate(¶ms, &series_ends, &input_times, &input_aggr_array) + .unwrap(); + insta::assert_yaml_snapshot!(array_to_lines(&time_arr, &arr), @r###" + --- + - +--------------------------------+------+ + - "| time | a0 |" + - 
+--------------------------------+------+ + - "| 1970-01-01T00:00:00.000001Z | |" + - "| 1970-01-01T00:00:00.000001100Z | 100 |" + - "| 1970-01-01T00:00:00.000001200Z | 133 |" + - "| 1970-01-01T00:00:00.000001300Z | 166 |" + - "| 1970-01-01T00:00:00.000001400Z | 200 |" + - "| 1970-01-01T00:00:00.000001500Z | 466 |" + - "| 1970-01-01T00:00:00.000001600Z | 733 |" + - "| 1970-01-01T00:00:00.000001700Z | 1000 |" + - "| 1970-01-01T00:00:00.000001800Z | 500 |" + - "| 1970-01-01T00:00:00.000001900Z | 0 |" + - "| 1970-01-01T00:00:00.000002Z | |" + - +--------------------------------+------+ + "###); + + assert_cursor_end_state(&cursor, &input_times, ¶ms); + } + + #[test] + fn test_interpolate_f64() { + let input_times = TimestampNanosecondArray::from(vec![ + // 1000 + Some(1100), + // 1200 + // 1300 + Some(1400), + Some(1500), + // 1600 + Some(1700), + // 1800 + Some(1900), + // 2000 + ]); + let input_aggr_array: ArrayRef = Arc::new(Float64Array::from(vec![ + // 1000 + Some(100.0), // 1100 + // 1200 + // 1300 + Some(400.0), // 1400 + None, // 1500 + // 1600 + Some(1000.0), // 1700 + // 1800 + Some(0.0), // 1900 + // 2000 + ])); + let series_ends = vec![input_times.len()]; + + let idx = 0; + let params = GapFillParams { + stride: 100, + first_ts: Some(1000), + last_ts: 2000, + fill_strategy: interpolate_fill_strategy(idx), + }; + + let output_batch_size = 10000; + let mut cursor = new_cursor_with_batch_size(¶ms, output_batch_size); + + let time_arr = TimestampNanosecondArray::from( + cursor + .clone_for_aggr_col(None) + .unwrap() + .build_time_vec(¶ms, &series_ends, &input_times) + .unwrap(), + ) + .with_timezone_opt(TIME_DATA_TIMEZONE()); + let arr = cursor + .build_aggr_fill_interpolate(¶ms, &series_ends, &input_times, &input_aggr_array) + .unwrap(); + insta::assert_yaml_snapshot!(array_to_lines(&time_arr, &arr), @r###" + --- + - +--------------------------------+--------+ + - "| time | a0 |" + - +--------------------------------+--------+ + - "| 
1970-01-01T00:00:00.000001Z | |" + - "| 1970-01-01T00:00:00.000001100Z | 100.0 |" + - "| 1970-01-01T00:00:00.000001200Z | 200.0 |" + - "| 1970-01-01T00:00:00.000001300Z | 300.0 |" + - "| 1970-01-01T00:00:00.000001400Z | 400.0 |" + - "| 1970-01-01T00:00:00.000001500Z | 600.0 |" + - "| 1970-01-01T00:00:00.000001600Z | 800.0 |" + - "| 1970-01-01T00:00:00.000001700Z | 1000.0 |" + - "| 1970-01-01T00:00:00.000001800Z | 500.0 |" + - "| 1970-01-01T00:00:00.000001900Z | 0.0 |" + - "| 1970-01-01T00:00:00.000002Z | |" + - +--------------------------------+--------+ + "###); + + assert_cursor_end_state(&cursor, &input_times, ¶ms); + } + + fn interpolate_fill_strategy(idx: usize) -> HashMap { + std::iter::once((idx, FillStrategy::LinearInterpolate)).collect() + } +} diff --git a/iox_query/src/exec/gapfill/buffered_input.rs b/iox_query/src/exec/gapfill/buffered_input.rs new file mode 100644 index 0000000..59ae311 --- /dev/null +++ b/iox_query/src/exec/gapfill/buffered_input.rs @@ -0,0 +1,502 @@ +//! Logic for buffering record batches for gap filling. + +use std::sync::Arc; + +use arrow::{ + array::{as_struct_array, ArrayRef}, + datatypes::DataType, + record_batch::RecordBatch, + row::{RowConverter, Rows, SortField}, +}; +use datafusion::error::{DataFusionError, Result}; +use hashbrown::HashSet; + +use super::{params::GapFillParams, FillStrategy}; + +/// Encapsulate the logic around how to buffer input records. +/// +/// If there are no columns with [`FillStrategy::LinearInterpolate`], then +/// we need to buffer up to the last input row that might appear in the output, plus +/// one additional row. +/// +/// However, if there are columns filled via interpolation, then we need +/// to ensure that we read ahead far enough to a non-null value, or a change +/// of group columns, in the columns being interpolated. 
+/// +/// [`FillStrategy::LinearInterpolate`]: super::FillStrategy::LinearInterpolate +/// [`GapFillStream`]: super::stream::GapFillStream +pub(super) struct BufferedInput { + /// Indexes of group columns in the schema (not including time). + group_cols: Vec, + /// Indexes of aggregate columns filled via interpolation. + interpolate_cols: Vec, + /// Buffered records from the input stream. + batches: Vec, + /// When gap filling with interpolated values, this row converter + /// is used to compare rows to see if group columns have changed. + row_converter: Option, + /// When gap filling with interpolated values, cache a row-oriented + /// representation of the last row that may appear in the output so + /// it doesn't need to be computed more than once. + last_output_row: Option, +} + +impl BufferedInput { + pub(super) fn new(params: &GapFillParams, group_cols: Vec) -> Self { + let interpolate_cols = params + .fill_strategy + .iter() + .filter_map(|(col_offset, fs)| { + (fs == &FillStrategy::LinearInterpolate).then_some(*col_offset) + }) + .collect::>(); + Self { + group_cols, + interpolate_cols, + batches: vec![], + row_converter: None, + last_output_row: None, + } + } + /// Add a new batch of buffered records from the input stream. + pub(super) fn push(&mut self, batch: RecordBatch) { + self.batches.push(batch); + } + + /// Transfer ownership of the buffered record batches to the caller for + /// processing. + pub(super) fn take(&mut self) -> Vec { + self.last_output_row = None; + std::mem::take(&mut self.batches) + } + + /// Determine if we need more input before we start processing. + pub(super) fn need_more(&mut self, last_output_row_offset: usize) -> Result { + let record_count: usize = self.batches.iter().map(|rb| rb.num_rows()).sum(); + // min number of rows needed is the number of rows up to and including + // the last row that may appear in the output, plus one more row. 
+ let min_needed = last_output_row_offset + 2; + + if record_count < min_needed { + return Ok(true); + } else if self.interpolate_cols.is_empty() { + return Ok(false); + } + + // Check to see if the last row that might appear in the output + // has a different group column values than the last buffered row. + // If they are different, then we have enough input to start. + let (last_output_batch_offset, last_output_row_offset) = self + .find_row_idx(last_output_row_offset) + .expect("checked record count"); + if self.group_columns_changed((last_output_batch_offset, last_output_row_offset))? { + return Ok(false); + } + + // Now check if there are non-null values in the columns being interpolated. + // We skip over the batches that come before the one that contains the last + // possible output row. We start with the last buffered batch, so we can avoid + // having to slice unless necessary. + let mut cols_that_need_more = + HashSet::::from_iter(self.interpolate_cols.iter().cloned()); + let mut to_remove = vec![]; + for (i, batch) in self + .batches + .iter() + .enumerate() + .skip(last_output_batch_offset) + .rev() + { + for col_offset in cols_that_need_more.clone() { + // If this is the batch containing the last possible output row, slice the + // array so we are just looking at that value and the ones after. 
+ let array = batch.column(col_offset); + let array = if i == last_output_batch_offset { + let length = array.len() - last_output_row_offset; + batch + .column(col_offset) + .slice(last_output_row_offset, length) + } else { + Arc::clone(array) + }; + + let struct_value_col = if let DataType::Struct(fields) = array.data_type().clone() { + fields.find("value").map(|(n, _)| n) + } else { + None + }; + + match struct_value_col { + Some(n) => { + let value_array = as_struct_array(&array).column(n); + if array.null_count() < array.len() + && value_array.null_count() < value_array.len() + { + to_remove.push(col_offset); + } + } + None => { + if array.null_count() < array.len() { + to_remove.push(col_offset); + } + } + } + } + + to_remove.drain(..).for_each(|c| { + cols_that_need_more.remove(&c); + }); + if cols_that_need_more.is_empty() { + break; + } + } + + Ok(!cols_that_need_more.is_empty()) + } + + /// Check to see if the group column values have changed between the last row + /// that may be in the output and the last buffered input row. + /// + /// This method uses the row-oriented representation of Arrow data from [`arrow::row`] to + /// compare rows in different record batches. + /// + /// [`arrow::row`]: https://docs.rs/arrow-row/36.0.0/arrow_row/index.html + fn group_columns_changed(&mut self, last_output_row_idx: (usize, usize)) -> Result { + if self.group_cols.is_empty() { + return Ok(false); + } + + let last_buffered_row_idx = self.last_buffered_row_idx(); + if last_output_row_idx == last_buffered_row_idx { + // the output row is also the last buffered row, + // so there is nothing to compare. + return Ok(false); + } + + let last_input_rows = self.convert_row(self.last_buffered_row_idx())?; + let last_row_in_output = self.last_output_row(last_output_row_idx)?; + + Ok(last_row_in_output.row(0) != last_input_rows.row(0)) + } + + /// Get a row converter for comparing records. Keep it in [`Self::row_converter`] + /// to avoid creating it multiple times. 
+ fn get_row_converter(&mut self) -> Result<&mut RowConverter> { + if self.row_converter.is_none() { + let batch = self.batches.first().expect("at least one batch"); + let sort_fields = self + .group_cols + .iter() + .map(|c| SortField::new(batch.column(*c).data_type().clone())) + .collect(); + let row_converter = RowConverter::new(sort_fields) + .map_err(|err| DataFusionError::ArrowError(err, None))?; + self.row_converter = Some(row_converter); + } + Ok(self.row_converter.as_mut().expect("cannot be none")) + } + + /// Convert a row to row-oriented format for easy comparison. + fn convert_row(&mut self, row_idxs: (usize, usize)) -> Result { + let batch = &self.batches[row_idxs.0]; + let columns: Vec = self + .group_cols + .iter() + .map(|col_idx| batch.column(*col_idx).slice(row_idxs.1, 1)) + .collect(); + self.get_row_converter()? + .convert_columns(&columns) + .map_err(|err| DataFusionError::ArrowError(err, None)) + } + + /// Returns the row-oriented representation of the last buffered row that may appear in the next + /// output batch. Since this row may be used multiple times, cache it in `self` to + /// avoid computing it multiple times. + fn last_output_row(&mut self, idxs: (usize, usize)) -> Result<&Rows> { + if self.last_output_row.is_none() { + let rows = self.convert_row(idxs)?; + self.last_output_row = Some(rows); + } + Ok(self.last_output_row.as_ref().expect("cannot be none")) + } + + /// Return the `(batch_idx, row_idx)` of the last buffered row. + fn last_buffered_row_idx(&self) -> (usize, usize) { + let last_batch_len = self.batches.last().unwrap().num_rows(); + (self.batches.len() - 1, last_batch_len - 1) + } + + /// Return the `(batch_idx, row_idx)` of the `nth` row. 
+    fn find_row_idx(&self, mut nth: usize) -> Option<(usize, usize)> {
+        let mut idx = None;
+        for (i, batch) in self.batches.iter().enumerate() {
+            if nth >= batch.num_rows() {
+                nth -= batch.num_rows()
+            } else {
+                idx = Some((i, nth));
+                break;
+            }
+        }
+        idx
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::collections::VecDeque;
+
+    use arrow_util::test_util::batches_to_lines;
+
+    use super::*;
+    use crate::exec::gapfill::exec_tests::TestRecords;
+
+    fn test_records(batch_size: usize) -> VecDeque<RecordBatch> {
+        let records = TestRecords {
+            group_cols: vec![
+                std::iter::repeat(Some("a")).take(12).collect(),
+                std::iter::repeat(Some("b"))
+                    .take(6)
+                    .chain(std::iter::repeat(Some("c")).take(6))
+                    .collect(),
+            ],
+            time_col: (0..12).map(|i| Some(1000 + i * 5)).take(12).collect(),
+            timezone: None,
+            agg_cols: vec![
+                vec![
+                    Some(1),
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    Some(10),
+                ],
+                vec![
+                    Some(2),
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    Some(20),
+                    None,
+                    None,
+                    None,
+                ],
+                (0..12).map(Some).collect(),
+            ],
+            struct_cols: vec![],
+            input_batch_size: batch_size,
+        };
+
+        TryInto::<Vec<RecordBatch>>::try_into(records)
+            .unwrap()
+            .into()
+    }
+
+    fn test_struct_records(batch_size: usize) -> VecDeque<RecordBatch> {
+        let records = TestRecords {
+            group_cols: vec![
+                std::iter::repeat(Some("a")).take(12).collect(),
+                std::iter::repeat(Some("b"))
+                    .take(6)
+                    .chain(std::iter::repeat(Some("c")).take(6))
+                    .collect(),
+            ],
+            time_col: (0..12).map(|i| Some(1000 + i * 5)).take(12).collect(),
+            timezone: None,
+            agg_cols: vec![],
+            struct_cols: vec![
+                vec![
+                    Some(vec![1, 0]),
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    Some(vec![10, 0]),
+                ],
+                vec![
+                    Some(vec![2, 0]),
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    Some(vec![20, 0]),
+                    None,
+                    None,
+                    None,
+                ],
+                (0..12).map(|n| Some(vec![n, 0])).collect(),
+            ],
+            input_batch_size: batch_size,
+        };
+
+        TryInto::<Vec<RecordBatch>>::try_into(records)
+            
.unwrap() + .into() + } + + fn test_params() -> GapFillParams { + GapFillParams { + stride: 50_000_000, + first_ts: Some(1_000_000_000), + last_ts: 1_055_000_000, + fill_strategy: [ + (3, FillStrategy::LinearInterpolate), + (4, FillStrategy::LinearInterpolate), + ] + .into(), + } + } + + // This test is just here so it's clear what the + // test data is + #[test] + fn test_test_records() { + let batch = test_records(1000).pop_front().unwrap(); + let actual = batches_to_lines(&[batch]); + insta::assert_yaml_snapshot!(actual, @r###" + --- + - +----+----+--------------------------+----+----+----+ + - "| g0 | g1 | time | a0 | a1 | a2 |" + - +----+----+--------------------------+----+----+----+ + - "| a | b | 1970-01-01T00:00:01Z | 1 | 2 | 0 |" + - "| a | b | 1970-01-01T00:00:01.005Z | | | 1 |" + - "| a | b | 1970-01-01T00:00:01.010Z | | | 2 |" + - "| a | b | 1970-01-01T00:00:01.015Z | | | 3 |" + - "| a | b | 1970-01-01T00:00:01.020Z | | | 4 |" + - "| a | b | 1970-01-01T00:00:01.025Z | | | 5 |" + - "| a | c | 1970-01-01T00:00:01.030Z | | | 6 |" + - "| a | c | 1970-01-01T00:00:01.035Z | | | 7 |" + - "| a | c | 1970-01-01T00:00:01.040Z | | 20 | 8 |" + - "| a | c | 1970-01-01T00:00:01.045Z | | | 9 |" + - "| a | c | 1970-01-01T00:00:01.050Z | | | 10 |" + - "| a | c | 1970-01-01T00:00:01.055Z | 10 | | 11 |" + - +----+----+--------------------------+----+----+----+ + "###); + } + + #[test] + fn no_group_no_interpolate() { + let batch_size = 3; + let mut params = test_params(); + params.fill_strategy = [].into(); + + let mut buffered_input = BufferedInput::new(¶ms, vec![]); + let mut batches = test_records(batch_size); + + // There are no rows, so that is less than the batch size, + // it needs more. + assert!(buffered_input.need_more(batch_size - 1).unwrap()); + + // There are now 3 rows, still less than batch_size + 1, + // so it needs more. 
+ buffered_input.push(batches.pop_front().unwrap()); + assert!(buffered_input.need_more(batch_size - 1).unwrap()); + + // We now have batch_size * 2, records, which is enough. + buffered_input.push(batches.pop_front().unwrap()); + assert!(!buffered_input.need_more(batch_size - 1).unwrap()); + } + + #[test] + fn no_group() { + let batch_size = 3; + let params = test_params(); + let mut buffered_input = BufferedInput::new(¶ms, vec![]); + let mut batches = test_records(batch_size); + + // There are no rows, so that is less than the batch size, + // it needs more. + assert!(buffered_input.need_more(batch_size - 1).unwrap()); + + // There are now 3 rows, still less than batch_size + 1, + // so it needs more. + buffered_input.push(batches.pop_front().unwrap()); + assert!(buffered_input.need_more(batch_size - 1).unwrap()); + + // There are now 6 rows, if we were not interpolating, + // this would be enough. + buffered_input.push(batches.pop_front().unwrap()); + + // If we are interpolating, there are no non null values + // at offset 5. + assert!(buffered_input.need_more(batch_size - 1).unwrap()); + + // Push more rows, now totaling 9. + buffered_input.push(batches.pop_front().unwrap()); + assert!(buffered_input.need_more(batch_size - 1).unwrap()); + // Column `a1` has a non-null value at offset 8. + // If that were the only column being interpolated, we would have enough. + + // 12 rows, with non-null values in both columns being interpolated. 
+ buffered_input.push(batches.pop_front().unwrap()); + assert!(!buffered_input.need_more(batch_size - 1).unwrap()); + } + + #[test] + fn with_group() { + let params = test_params(); + let group_cols = vec![0, 1]; + let mut buffered_input = BufferedInput::new(¶ms, group_cols); + + let batch_size = 3; + let mut batches = test_records(batch_size); + + // no rows + assert!(buffered_input.need_more(batch_size - 1).unwrap()); + + // 3 rows + buffered_input.push(batches.pop_front().unwrap()); + assert!(buffered_input.need_more(batch_size - 1).unwrap()); + + // 6 rows + buffered_input.push(batches.pop_front().unwrap()); + assert!(buffered_input.need_more(batch_size - 1).unwrap()); + + // 9 rows (series changes here) + buffered_input.push(batches.pop_front().unwrap()); + assert!(!buffered_input.need_more(batch_size - 1).unwrap()); + } + + #[test] + fn struct_with_group() { + let params = test_params(); + let group_cols = vec![0, 1]; + let mut buffered_input = BufferedInput::new(¶ms, group_cols); + + let batch_size = 3; + let mut batches = test_struct_records(batch_size); + + // no rows + assert!(buffered_input.need_more(batch_size - 1).unwrap()); + + // 3 rows + buffered_input.push(batches.pop_front().unwrap()); + assert!(buffered_input.need_more(batch_size - 1).unwrap()); + + // 6 rows + buffered_input.push(batches.pop_front().unwrap()); + assert!(buffered_input.need_more(batch_size - 1).unwrap()); + + // 9 rows (series changes here) + buffered_input.push(batches.pop_front().unwrap()); + assert!(!buffered_input.need_more(batch_size - 1).unwrap()); + } +} diff --git a/iox_query/src/exec/gapfill/exec_tests.rs b/iox_query/src/exec/gapfill/exec_tests.rs new file mode 100644 index 0000000..cc0a190 --- /dev/null +++ b/iox_query/src/exec/gapfill/exec_tests.rs @@ -0,0 +1,1619 @@ +//! Tests that verify output produced by [GapFillExec]. 
+ +use std::{ + cmp::Ordering, + ops::{Bound, Range}, +}; + +use super::*; +use arrow::{ + array::{ArrayRef, DictionaryArray, Int64Array, StructArray, TimestampNanosecondArray}, + datatypes::{DataType, Field, Fields, Int32Type, Schema, TimeUnit}, + record_batch::RecordBatch, +}; +use arrow_util::test_util::batches_to_lines; +use datafusion::{ + error::Result, + execution::runtime_env::{RuntimeConfig, RuntimeEnv}, + physical_plan::{ + collect, expressions::col as phys_col, expressions::lit as phys_lit, memory::MemoryExec, + }, + prelude::{SessionConfig, SessionContext}, + scalar::ScalarValue, +}; +use futures::executor::block_on; +use observability_deps::tracing::debug; +use schema::{InfluxColumnType, InfluxFieldType}; +use test_helpers::assert_error; + +#[test] +fn test_gapfill_simple() { + test_helpers::maybe_start_logging(); + insta::allow_duplicates! { for output_batch_size in [1, 2, 4, 8] { + for input_batch_size in [1, 2] { + let batch = TestRecords { + group_cols: vec![vec![Some("a"), Some("a")]], + time_col: vec![Some(1_000), Some(1_100)], + timezone: None, + agg_cols: vec![vec![Some(10), Some(11)]], + struct_cols: vec![], + input_batch_size, + }; + let params = get_params_ms(&batch, 25, Some(975), 1_125); + let tc = TestCase { + test_records: batch, + output_batch_size, + params, + }; + // For this simple test case, also test that + // memory is tracked correctly, which is done by + // TestCase when running with a memory limit. 
+ let batches = tc.run_with_memory_limit(16384).unwrap(); + let actual = batches_to_lines(&batches); + insta::assert_yaml_snapshot!(actual, @r###" + --- + - +----+--------------------------+----+ + - "| g0 | time | a0 |" + - +----+--------------------------+----+ + - "| a | 1970-01-01T00:00:00.975Z | |" + - "| a | 1970-01-01T00:00:01Z | 10 |" + - "| a | 1970-01-01T00:00:01.025Z | |" + - "| a | 1970-01-01T00:00:01.050Z | |" + - "| a | 1970-01-01T00:00:01.075Z | |" + - "| a | 1970-01-01T00:00:01.100Z | 11 |" + - "| a | 1970-01-01T00:00:01.125Z | |" + - +----+--------------------------+----+ + "###); + assert_batch_count(&batches, output_batch_size); + } + }} +} + +#[test] +fn test_gapfill_simple_tz() { + test_helpers::maybe_start_logging(); + insta::allow_duplicates! { for output_batch_size in [1, 2, 4, 8] { + for input_batch_size in [1, 2] { + let batch = TestRecords { + group_cols: vec![vec![Some("a"), Some("a")]], + time_col: vec![Some(1_000), Some(1_100)], + timezone: Some("Australia/Adelaide".into()), + agg_cols: vec![vec![Some(10), Some(11)]], + struct_cols: vec![], + input_batch_size, + }; + let params = get_params_ms(&batch, 25, Some(975), 1_125); + let tc = TestCase { + test_records: batch, + output_batch_size, + params, + }; + // For this simple test case, also test that + // memory is tracked correctly, which is done by + // TestCase when running with a memory limit. 
+ let batches = tc.run_with_memory_limit(16384).unwrap(); + let actual = batches_to_lines(&batches); + insta::assert_yaml_snapshot!(actual, @r###" + --- + - +----+-------------------------------+----+ + - "| g0 | time | a0 |" + - +----+-------------------------------+----+ + - "| a | 1970-01-01T09:30:00.975+09:30 | |" + - "| a | 1970-01-01T09:30:01+09:30 | 10 |" + - "| a | 1970-01-01T09:30:01.025+09:30 | |" + - "| a | 1970-01-01T09:30:01.050+09:30 | |" + - "| a | 1970-01-01T09:30:01.075+09:30 | |" + - "| a | 1970-01-01T09:30:01.100+09:30 | 11 |" + - "| a | 1970-01-01T09:30:01.125+09:30 | |" + - +----+-------------------------------+----+ + "###); + assert_batch_count(&batches, output_batch_size); + } + }} +} + +#[test] +fn test_gapfill_simple_no_group_no_aggr() { + // There may be no group columns in a gap fill query, + // and there may be no aggregate columns as well. + // Such a query is not all that useful but it should work. + test_helpers::maybe_start_logging(); + insta::allow_duplicates! 
{ for output_batch_size in [1, 2, 4, 8] { + for input_batch_size in [1, 2, 4] { + let batch = TestRecords { + group_cols: vec![], + time_col: vec![None, Some(1_000), Some(1_100)], + timezone: None, + agg_cols: vec![], + struct_cols: vec![], + input_batch_size, + }; + let params = get_params_ms(&batch, 25, Some(975), 1_125); + let tc = TestCase { + test_records: batch, + output_batch_size, + params, + }; + let batches = tc.run().unwrap(); + let actual = batches_to_lines(&batches); + insta::assert_yaml_snapshot!(actual, @r###" + --- + - +--------------------------+ + - "| time |" + - +--------------------------+ + - "| |" + - "| 1970-01-01T00:00:00.975Z |" + - "| 1970-01-01T00:00:01Z |" + - "| 1970-01-01T00:00:01.025Z |" + - "| 1970-01-01T00:00:01.050Z |" + - "| 1970-01-01T00:00:01.075Z |" + - "| 1970-01-01T00:00:01.100Z |" + - "| 1970-01-01T00:00:01.125Z |" + - +--------------------------+ + "###); + assert_batch_count(&batches, output_batch_size); + } + }} +} + +#[test] +fn test_gapfill_multi_group_simple() { + test_helpers::maybe_start_logging(); + insta::allow_duplicates! 
{ for output_batch_size in [1, 2, 4, 8, 16] { + for input_batch_size in [1, 2, 4] { + let records = TestRecords { + group_cols: vec![vec![Some("a"), Some("a"), Some("b"), Some("b")]], + time_col: vec![Some(1_000), Some(1_100), Some(1_025), Some(1_050)], + timezone: None, + agg_cols: vec![vec![Some(10), Some(11), Some(20), Some(21)]], + struct_cols: vec![], + input_batch_size, + }; + let params = get_params_ms(&records, 25, Some(975), 1_125); + let tc = TestCase { + test_records: records, + output_batch_size, + params, + }; + let batches = tc.run().unwrap(); + let actual = batches_to_lines(&batches); + insta::assert_yaml_snapshot!(actual, @r###" + --- + - +----+--------------------------+----+ + - "| g0 | time | a0 |" + - +----+--------------------------+----+ + - "| a | 1970-01-01T00:00:00.975Z | |" + - "| a | 1970-01-01T00:00:01Z | 10 |" + - "| a | 1970-01-01T00:00:01.025Z | |" + - "| a | 1970-01-01T00:00:01.050Z | |" + - "| a | 1970-01-01T00:00:01.075Z | |" + - "| a | 1970-01-01T00:00:01.100Z | 11 |" + - "| a | 1970-01-01T00:00:01.125Z | |" + - "| b | 1970-01-01T00:00:00.975Z | |" + - "| b | 1970-01-01T00:00:01Z | |" + - "| b | 1970-01-01T00:00:01.025Z | 20 |" + - "| b | 1970-01-01T00:00:01.050Z | 21 |" + - "| b | 1970-01-01T00:00:01.075Z | |" + - "| b | 1970-01-01T00:00:01.100Z | |" + - "| b | 1970-01-01T00:00:01.125Z | |" + - +----+--------------------------+----+ + "###); + assert_batch_count(&batches, output_batch_size); + } + }} +} + +#[test] +fn test_gapfill_multi_group_simple_origin() { + test_helpers::maybe_start_logging(); + insta::allow_duplicates! 
{ for output_batch_size in [1, 2, 4, 8, 16] { + for input_batch_size in [1, 2, 4] { + let records = TestRecords { + group_cols: vec![vec![Some("a"), Some("a"), Some("b"), Some("b")]], + time_col: vec![Some(1_000), Some(1_100), Some(1_025), Some(1_050)], + timezone: None, + agg_cols: vec![vec![Some(10), Some(11), Some(20), Some(21)]], + struct_cols: vec![], + input_batch_size, + }; + let params = get_params_ms_with_origin_fill_strategy(&records, 25, Some(975), 1_125, Some(3), FillStrategy::Null); + let tc = TestCase { + test_records: records, + output_batch_size, + params, + }; + let batches = tc.run().unwrap(); + let actual = batches_to_lines(&batches); + // timestamps are now offset by 3ms + insta::assert_yaml_snapshot!(actual, @r###" + --- + - +----+--------------------------+----+ + - "| g0 | time | a0 |" + - +----+--------------------------+----+ + - "| a | 1970-01-01T00:00:00.953Z | |" + - "| a | 1970-01-01T00:00:00.978Z | |" + - "| a | 1970-01-01T00:00:01.003Z | 10 |" + - "| a | 1970-01-01T00:00:01.028Z | |" + - "| a | 1970-01-01T00:00:01.053Z | |" + - "| a | 1970-01-01T00:00:01.078Z | |" + - "| a | 1970-01-01T00:00:01.103Z | 11 |" + - "| b | 1970-01-01T00:00:00.953Z | |" + - "| b | 1970-01-01T00:00:00.978Z | |" + - "| b | 1970-01-01T00:00:01.003Z | |" + - "| b | 1970-01-01T00:00:01.028Z | 20 |" + - "| b | 1970-01-01T00:00:01.053Z | 21 |" + - "| b | 1970-01-01T00:00:01.078Z | |" + - "| b | 1970-01-01T00:00:01.103Z | |" + - +----+--------------------------+----+ + "###); + assert_batch_count(&batches, output_batch_size); + } + }} +} + +#[test] +fn test_gapfill_multi_group_with_nulls() { + test_helpers::maybe_start_logging(); + insta::allow_duplicates! 
{ for output_batch_size in [1, 2, 4, 8, 16, 32] { + for input_batch_size in [1, 2, 4, 8] { + let records = TestRecords { + group_cols: vec![vec![ + Some("a"), + Some("a"), + Some("a"), + Some("a"), + Some("b"), + Some("b"), + Some("b"), + ]], + time_col: vec![ + None, + None, + Some(1_000), + Some(1_100), + None, + Some(1_000), + Some(1_100), + ], + timezone: None, + agg_cols: vec![vec![ + Some(1), + None, + Some(10), + Some(11), + Some(2), + Some(20), + Some(21), + ]], + struct_cols: vec![], + input_batch_size, + }; + let params = get_params_ms(&records, 25, Some(975), 1_125); + let tc = TestCase { + test_records: records, + output_batch_size, + params, + }; + let batches = tc.run().unwrap(); + let actual = batches_to_lines(&batches); + insta::assert_yaml_snapshot!(actual, @r###" + --- + - +----+--------------------------+----+ + - "| g0 | time | a0 |" + - +----+--------------------------+----+ + - "| a | | 1 |" + - "| a | | |" + - "| a | 1970-01-01T00:00:00.975Z | |" + - "| a | 1970-01-01T00:00:01Z | 10 |" + - "| a | 1970-01-01T00:00:01.025Z | |" + - "| a | 1970-01-01T00:00:01.050Z | |" + - "| a | 1970-01-01T00:00:01.075Z | |" + - "| a | 1970-01-01T00:00:01.100Z | 11 |" + - "| a | 1970-01-01T00:00:01.125Z | |" + - "| b | | 2 |" + - "| b | 1970-01-01T00:00:00.975Z | |" + - "| b | 1970-01-01T00:00:01Z | 20 |" + - "| b | 1970-01-01T00:00:01.025Z | |" + - "| b | 1970-01-01T00:00:01.050Z | |" + - "| b | 1970-01-01T00:00:01.075Z | |" + - "| b | 1970-01-01T00:00:01.100Z | 21 |" + - "| b | 1970-01-01T00:00:01.125Z | |" + - +----+--------------------------+----+ + "###); + assert_batch_count(&batches, output_batch_size); + } + }} +} + +#[test] +fn test_gapfill_multi_group_cols_with_nulls() { + test_helpers::maybe_start_logging(); + insta::allow_duplicates! 
{ for output_batch_size in [1, 2, 4, 8, 16, 32] { + for input_batch_size in [1, 2, 4, 8] { + let records = TestRecords { + group_cols: vec![ + vec![ + Some("a"), + Some("a"), + Some("a"), + Some("a"), + Some("a"), + Some("a"), + Some("a"), + ], + vec![ + Some("c"), + Some("c"), + Some("c"), + Some("c"), + Some("d"), + Some("d"), + Some("d"), + ], + ], + time_col: vec![ + None, + None, + Some(1_000), + Some(1_100), + None, + Some(1_000), + Some(1_100), + ], + timezone: None, + agg_cols: vec![vec![ + Some(1), + None, + Some(10), + Some(11), + Some(2), + Some(20), + Some(21), + ]], + struct_cols: vec![], + input_batch_size, + }; + let params = get_params_ms(&records, 25, Some(975), 1_125); + let tc = TestCase { + test_records: records, + output_batch_size, + params, + }; + let batches = tc.run().unwrap(); + let actual = batches_to_lines(&batches); + insta::assert_yaml_snapshot!(actual, @r###" + --- + - +----+----+--------------------------+----+ + - "| g0 | g1 | time | a0 |" + - +----+----+--------------------------+----+ + - "| a | c | | 1 |" + - "| a | c | | |" + - "| a | c | 1970-01-01T00:00:00.975Z | |" + - "| a | c | 1970-01-01T00:00:01Z | 10 |" + - "| a | c | 1970-01-01T00:00:01.025Z | |" + - "| a | c | 1970-01-01T00:00:01.050Z | |" + - "| a | c | 1970-01-01T00:00:01.075Z | |" + - "| a | c | 1970-01-01T00:00:01.100Z | 11 |" + - "| a | c | 1970-01-01T00:00:01.125Z | |" + - "| a | d | | 2 |" + - "| a | d | 1970-01-01T00:00:00.975Z | |" + - "| a | d | 1970-01-01T00:00:01Z | 20 |" + - "| a | d | 1970-01-01T00:00:01.025Z | |" + - "| a | d | 1970-01-01T00:00:01.050Z | |" + - "| a | d | 1970-01-01T00:00:01.075Z | |" + - "| a | d | 1970-01-01T00:00:01.100Z | 21 |" + - "| a | d | 1970-01-01T00:00:01.125Z | |" + - +----+----+--------------------------+----+ + "###); + assert_batch_count(&batches, output_batch_size); + } + }} +} + +#[test] +fn test_gapfill_multi_group_cols_with_more_nulls() { + test_helpers::maybe_start_logging(); + insta::allow_duplicates! 
{ for output_batch_size in [1, 2, 4, 8, 16, 32] { + for input_batch_size in [1, 2, 4, 8] { + let records = TestRecords { + group_cols: vec![vec![Some("a"), Some("b"), Some("b"), Some("b"), Some("b")]], + time_col: vec![ + Some(1_000), + None, // group b + None, + None, + None, + ], + timezone: None, + agg_cols: vec![vec![ + Some(10), // group a + Some(90), // group b + Some(91), + Some(92), + Some(93), + ]], + struct_cols: vec![], + input_batch_size, + }; + let params = get_params_ms(&records, 25, Some(975), 1_025); + let tc = TestCase { + test_records: records, + output_batch_size, + params, + }; + let batches = tc.run().unwrap(); + let actual = batches_to_lines(&batches); + insta::assert_yaml_snapshot!(actual, @r###" + --- + - +----+--------------------------+----+ + - "| g0 | time | a0 |" + - +----+--------------------------+----+ + - "| a | 1970-01-01T00:00:00.975Z | |" + - "| a | 1970-01-01T00:00:01Z | 10 |" + - "| a | 1970-01-01T00:00:01.025Z | |" + - "| b | | 90 |" + - "| b | | 91 |" + - "| b | | 92 |" + - "| b | | 93 |" + - "| b | 1970-01-01T00:00:00.975Z | |" + - "| b | 1970-01-01T00:00:01Z | |" + - "| b | 1970-01-01T00:00:01.025Z | |" + - +----+--------------------------+----+ + "###); + assert_batch_count(&batches, output_batch_size); + } + }} +} + +#[test] +fn test_gapfill_multi_aggr_cols_with_nulls() { + test_helpers::maybe_start_logging(); + insta::allow_duplicates! 
{ for output_batch_size in [1, 2, 4, 8, 16, 32] { + for input_batch_size in [1, 2, 4, 8] { + let records = TestRecords { + group_cols: vec![ + vec![ + Some("a"), + Some("a"), + Some("a"), + Some("a"), + Some("b"), + Some("b"), + Some("b"), + ], + vec![ + Some("c"), + Some("c"), + Some("c"), + Some("c"), + Some("d"), + Some("d"), + Some("d"), + ], + ], + time_col: vec![ + None, + None, + Some(1_000), + Some(1_100), + None, + Some(1_000), + Some(1_100), + ], + timezone: None, + agg_cols: vec![ + vec![ + Some(1), + None, + Some(10), + Some(11), + Some(2), + Some(20), + Some(21), + ], + vec![ + Some(3), + Some(3), + Some(30), + None, + Some(4), + Some(40), + Some(41), + ], + ], + struct_cols: vec![], + input_batch_size, + }; + let params = get_params_ms(&records, 25, Some(975), 1_125); + let tc = TestCase { + test_records: records, + output_batch_size, + params, + }; + let batches = tc.run().unwrap(); + let actual = batches_to_lines(&batches); + insta::assert_yaml_snapshot!(actual, @r###" + --- + - +----+----+--------------------------+----+----+ + - "| g0 | g1 | time | a0 | a1 |" + - +----+----+--------------------------+----+----+ + - "| a | c | | 1 | 3 |" + - "| a | c | | | 3 |" + - "| a | c | 1970-01-01T00:00:00.975Z | | |" + - "| a | c | 1970-01-01T00:00:01Z | 10 | 30 |" + - "| a | c | 1970-01-01T00:00:01.025Z | | |" + - "| a | c | 1970-01-01T00:00:01.050Z | | |" + - "| a | c | 1970-01-01T00:00:01.075Z | | |" + - "| a | c | 1970-01-01T00:00:01.100Z | 11 | |" + - "| a | c | 1970-01-01T00:00:01.125Z | | |" + - "| b | d | | 2 | 4 |" + - "| b | d | 1970-01-01T00:00:00.975Z | | |" + - "| b | d | 1970-01-01T00:00:01Z | 20 | 40 |" + - "| b | d | 1970-01-01T00:00:01.025Z | | |" + - "| b | d | 1970-01-01T00:00:01.050Z | | |" + - "| b | d | 1970-01-01T00:00:01.075Z | | |" + - "| b | d | 1970-01-01T00:00:01.100Z | 21 | 41 |" + - "| b | d | 1970-01-01T00:00:01.125Z | | |" + - +----+----+--------------------------+----+----+ + "###); + assert_batch_count(&batches, 
output_batch_size); + } + }} +} + +#[test] +fn test_gapfill_simple_no_lower_bound() { + test_helpers::maybe_start_logging(); + insta::allow_duplicates! { for output_batch_size in [1, 2, 4, 8] { + for input_batch_size in [1, 2, 4] { + let batch = TestRecords { + group_cols: vec![vec![Some("a"), Some("a"), Some("b"), Some("b")]], + time_col: vec![Some(1_025), Some(1_100), Some(1_050), Some(1_100)], + timezone: None, + agg_cols: vec![vec![Some(10), Some(11), Some(20), Some(21)]], + struct_cols: vec![], + input_batch_size, + }; + let params = get_params_ms(&batch, 25, None, 1_125); + let tc = TestCase { + test_records: batch, + output_batch_size, + params, + }; + let batches = tc.run().unwrap(); + let actual = batches_to_lines(&batches); + insta::assert_yaml_snapshot!(actual, @r###" + --- + - +----+--------------------------+----+ + - "| g0 | time | a0 |" + - +----+--------------------------+----+ + - "| a | 1970-01-01T00:00:01.025Z | 10 |" + - "| a | 1970-01-01T00:00:01.050Z | |" + - "| a | 1970-01-01T00:00:01.075Z | |" + - "| a | 1970-01-01T00:00:01.100Z | 11 |" + - "| a | 1970-01-01T00:00:01.125Z | |" + - "| b | 1970-01-01T00:00:01.050Z | 20 |" + - "| b | 1970-01-01T00:00:01.075Z | |" + - "| b | 1970-01-01T00:00:01.100Z | 21 |" + - "| b | 1970-01-01T00:00:01.125Z | |" + - +----+--------------------------+----+ + "###); + assert_batch_count(&batches, output_batch_size); + } + }} +} + +#[test] +fn test_gapfill_fill_prev() { + test_helpers::maybe_start_logging(); + insta::allow_duplicates! 
{ for output_batch_size in [1, 2, 4, 8] { + for input_batch_size in [1, 2, 4] { + let records = TestRecords { + group_cols: vec![vec![ + Some("a"), + Some("a"), + Some("b"), + Some("b"), + Some("b"), + ]], + time_col: vec![ + // 975 + Some(1000), + // 1025 + // 1050 + Some(1075), + // 1100 + // 1125 + // --- new series + // 975 + Some(1000), + // 1025 + Some(1050), + // 1075 + Some(1100), + // 1125 + ], + timezone: None, + agg_cols: vec![vec![ + Some(10), + Some(11), + Some(20), + None, + Some(21), + ]], + struct_cols: vec![], + input_batch_size, + }; + let params = get_params_ms_with_fill_strategy(&records, 25, Some(975), 1_125, FillStrategy::PrevNullAsIntentional); + let tc = TestCase { + test_records: records, + output_batch_size, + params, + }; + let batches = tc.run().unwrap(); + let actual = batches_to_lines(&batches); + insta::with_settings!({ + description => format!("input_batch_size: {input_batch_size}, output_batch_size: {output_batch_size}"), + }, { + insta::assert_yaml_snapshot!(actual, @r###" + --- + - +----+--------------------------+----+ + - "| g0 | time | a0 |" + - +----+--------------------------+----+ + - "| a | 1970-01-01T00:00:00.975Z | |" + - "| a | 1970-01-01T00:00:01Z | 10 |" + - "| a | 1970-01-01T00:00:01.025Z | 10 |" + - "| a | 1970-01-01T00:00:01.050Z | 10 |" + - "| a | 1970-01-01T00:00:01.075Z | 11 |" + - "| a | 1970-01-01T00:00:01.100Z | 11 |" + - "| a | 1970-01-01T00:00:01.125Z | 11 |" + - "| b | 1970-01-01T00:00:00.975Z | |" + - "| b | 1970-01-01T00:00:01Z | 20 |" + - "| b | 1970-01-01T00:00:01.025Z | 20 |" + - "| b | 1970-01-01T00:00:01.050Z | |" + - "| b | 1970-01-01T00:00:01.075Z | |" + - "| b | 1970-01-01T00:00:01.100Z | 21 |" + - "| b | 1970-01-01T00:00:01.125Z | 21 |" + - +----+--------------------------+----+ + "###) + }); + assert_batch_count(&batches, output_batch_size); + } + }} +} + +#[test] +fn test_gapfill_fill_prev_null_as_missing() { + test_helpers::maybe_start_logging(); + insta::allow_duplicates! 
{ + for output_batch_size in [16, 1] { + for input_batch_size in [8, 1] { + let records = TestRecords { + group_cols: vec![vec![ + Some("a"), + Some("a"), + Some("b"), + Some("b"), + Some("b"), + ]], + time_col: vec![ + // 975 + Some(1000), + // 1025 + // 1050 + Some(1075), + // 1100 + // 1125 + // --- new series + // 975 + Some(1000), + // 1025 + Some(1050), + // 1075 + Some(1100), + // 1125 + ], + timezone: None, + agg_cols: vec![vec![ + Some(10), // a: 1000 + None, // a: 1075 + Some(20), // b: 1000 + None, // b: 1050 + Some(21), // b: 1100 + ]], + struct_cols: vec![], + input_batch_size, + }; + let params = get_params_ms_with_fill_strategy(&records, 25, Some(975), 1_125, FillStrategy::PrevNullAsMissing); + let tc = TestCase { + test_records: records, + output_batch_size, + params, + }; + let batches = tc.run().unwrap(); + let actual = batches_to_lines(&batches); + insta::with_settings!({ + description => format!("input_batch_size: {input_batch_size}, output_batch_size: {output_batch_size}"), + }, { + insta::assert_yaml_snapshot!(actual, @r###" + --- + - +----+--------------------------+----+ + - "| g0 | time | a0 |" + - +----+--------------------------+----+ + - "| a | 1970-01-01T00:00:00.975Z | |" + - "| a | 1970-01-01T00:00:01Z | 10 |" + - "| a | 1970-01-01T00:00:01.025Z | 10 |" + - "| a | 1970-01-01T00:00:01.050Z | 10 |" + - "| a | 1970-01-01T00:00:01.075Z | 10 |" + - "| a | 1970-01-01T00:00:01.100Z | 10 |" + - "| a | 1970-01-01T00:00:01.125Z | 10 |" + - "| b | 1970-01-01T00:00:00.975Z | |" + - "| b | 1970-01-01T00:00:01Z | 20 |" + - "| b | 1970-01-01T00:00:01.025Z | 20 |" + - "| b | 1970-01-01T00:00:01.050Z | 20 |" + - "| b | 1970-01-01T00:00:01.075Z | 20 |" + - "| b | 1970-01-01T00:00:01.100Z | 21 |" + - "| b | 1970-01-01T00:00:01.125Z | 21 |" + - +----+--------------------------+----+ + "###) + }); + assert_batch_count(&batches, output_batch_size); + } + }} +} + +#[test] +fn test_gapfill_fill_prev_null_as_missing_many_nulls() { + 
test_helpers::maybe_start_logging(); + insta::allow_duplicates! { + for output_batch_size in [16, 1] { + for input_batch_size in [8, 1] { + let records = TestRecords { + group_cols: vec![vec![ + Some("a"), + Some("a"), + Some("a"), + Some("a"), + Some("a"), + Some("a"), + // --- new series + Some("b"), + Some("b"), + Some("b"), + Some("b"), + Some("b"), + ]], + time_col: vec![ + None, + Some(975), + Some(1000), + Some(1025), + Some(1050), + // 1075 + Some(1100), + // 1125 + // --- new series + None, + Some(975), + // 1000 + Some(1025), + Some(1050), + // 1075 + Some(1100), + // 1125 + ], + timezone: None, + agg_cols: vec![vec![ + Some(-1), // a: null ts + Some(10), // a: 975 + None, // a: 1000 + None, // a: 1025 (stashed) + None, // a: 1050 (stashed) + // a: 1075 (stashed) + Some(12), // a: 1100 + // a: 1125 + // --- new series + Some(-2), // b: null ts + None, // b: 975 + // b: 1000 + Some(21), // b: 1025 + None, // b: 1050 + // b: 1075 + Some(22), // b: 1100 + // b: 1125 + ]], + struct_cols: vec![], + input_batch_size, + }; + let params = get_params_ms_with_fill_strategy(&records, 25, Some(975), 1_125, FillStrategy::PrevNullAsMissing); + let tc = TestCase { + test_records: records, + output_batch_size, + params, + }; + let batches = tc.run().unwrap(); + let actual = batches_to_lines(&batches); + insta::with_settings!({ + description => format!("input_batch_size: {input_batch_size}, output_batch_size: {output_batch_size}"), + }, { + insta::assert_yaml_snapshot!(actual, @r###" + --- + - +----+--------------------------+----+ + - "| g0 | time | a0 |" + - +----+--------------------------+----+ + - "| a | | -1 |" + - "| a | 1970-01-01T00:00:00.975Z | 10 |" + - "| a | 1970-01-01T00:00:01Z | 10 |" + - "| a | 1970-01-01T00:00:01.025Z | 10 |" + - "| a | 1970-01-01T00:00:01.050Z | 10 |" + - "| a | 1970-01-01T00:00:01.075Z | 10 |" + - "| a | 1970-01-01T00:00:01.100Z | 12 |" + - "| a | 1970-01-01T00:00:01.125Z | 12 |" + - "| b | | -2 |" + - "| b | 1970-01-01T00:00:00.975Z | 
|" + - "| b | 1970-01-01T00:00:01Z | |" + - "| b | 1970-01-01T00:00:01.025Z | 21 |" + - "| b | 1970-01-01T00:00:01.050Z | 21 |" + - "| b | 1970-01-01T00:00:01.075Z | 21 |" + - "| b | 1970-01-01T00:00:01.100Z | 22 |" + - "| b | 1970-01-01T00:00:01.125Z | 22 |" + - +----+--------------------------+----+ + "###) + }); + assert_batch_count(&batches, output_batch_size); + } + }} +} + +/// Show that: +/// - we can have multiple interpolated segments within +/// a series +/// - a null value will break interpolation +/// - times before the first or after the last non-null data point +/// in a series are filled with nulls. +#[test] +fn test_gapfill_fill_interpolate() { + test_helpers::maybe_start_logging(); + insta::allow_duplicates! { + for output_batch_size in [16, 1] { + let input_batch_size = 8; + let records = TestRecords { + group_cols: vec![vec![ + Some("a"), + Some("a"), + Some("a"), + // --- new series + Some("b"), + Some("b"), + Some("b"), + Some("b"), + Some("b"), + Some("b"), + ]], + time_col: vec![ + None, + // 975 + Some(1000), + // 1025 + // 1050 + Some(1075), + // 1100 + // 1125 + // --- new series + None, + Some(975), + Some(1000), + Some(1025), + // 1050 + Some(1075), + // 1100 + Some(1125), + ], + timezone: None, + agg_cols: vec![vec![ + Some(-1), + // null, 975 + Some(100), // 1000 + // 200 1025 + // 300 1050 + Some(400), // 1075 + // 1100 + // 1125 + // --- new series + Some(-10), + Some(1100), // 975 + None, // 1200 1000 (this null value will be filled) + Some(1300), // 1025 + // 1325 1050 + Some(1350), // 1075 + Some(1550), // 1100 + // 1125 + ]], + struct_cols: vec![], + input_batch_size, + }; + let params = get_params_ms_with_fill_strategy( + &records, + 25, + Some(975), + 1_125, + FillStrategy::LinearInterpolate + ); + let tc = TestCase { + test_records: records, + output_batch_size, + params, + }; + let batches = tc.run().unwrap(); + let actual = batches_to_lines(&batches); + insta::with_settings!({ + description => format!("input_batch_size: 
{input_batch_size}, output_batch_size: {output_batch_size}"), + }, { + insta::assert_yaml_snapshot!(actual, @r###" + --- + - +----+--------------------------+------+ + - "| g0 | time | a0 |" + - +----+--------------------------+------+ + - "| a | | -1 |" + - "| a | 1970-01-01T00:00:00.975Z | |" + - "| a | 1970-01-01T00:00:01Z | 100 |" + - "| a | 1970-01-01T00:00:01.025Z | 200 |" + - "| a | 1970-01-01T00:00:01.050Z | 300 |" + - "| a | 1970-01-01T00:00:01.075Z | 400 |" + - "| a | 1970-01-01T00:00:01.100Z | |" + - "| a | 1970-01-01T00:00:01.125Z | |" + - "| b | | -10 |" + - "| b | 1970-01-01T00:00:00.975Z | 1100 |" + - "| b | 1970-01-01T00:00:01Z | 1200 |" + - "| b | 1970-01-01T00:00:01.025Z | 1300 |" + - "| b | 1970-01-01T00:00:01.050Z | 1325 |" + - "| b | 1970-01-01T00:00:01.075Z | 1350 |" + - "| b | 1970-01-01T00:00:01.100Z | 1450 |" + - "| b | 1970-01-01T00:00:01.125Z | 1550 |" + - +----+--------------------------+------+ + "###) + }); + assert_batch_count(&batches, output_batch_size); + } + } +} + +#[test] +fn test_gapfill_simple_no_lower_bound_with_nulls() { + test_helpers::maybe_start_logging(); + insta::allow_duplicates! 
{ for output_batch_size in [1, 2, 4, 8] { + for input_batch_size in [1, 2, 4] { + let batch = TestRecords { + group_cols: vec![vec![ + Some("a"), + Some("a"), + Some("a"), + Some("b"), + Some("b"), + Some("b"), + Some("b"), + Some("c"), + Some("c"), + Some("c"), + Some("c"), + Some("c"), + ]], + time_col: vec![ + None, // group a + Some(1_025), + Some(1_100), + None, // group b + None, + None, + None, // group c + None, + None, + None, + Some(1_050), + Some(1_100), + ], + timezone: None, + agg_cols: vec![vec![ + Some(1), // group a + Some(10), + Some(11), + Some(90), // group b + Some(91), + Some(92), + Some(93), + None, // group c + None, + Some(2), + Some(20), + Some(21), + ]], + struct_cols: vec![], + input_batch_size, + }; + let params = get_params_ms(&batch, 25, None, 1_125); + let tc = TestCase { + test_records: batch, + output_batch_size, + params, + }; + let batches = tc.run().unwrap(); + let actual = batches_to_lines(&batches); + insta::assert_yaml_snapshot!(actual, @r###" + --- + - +----+--------------------------+----+ + - "| g0 | time | a0 |" + - +----+--------------------------+----+ + - "| a | | 1 |" + - "| a | 1970-01-01T00:00:01.025Z | 10 |" + - "| a | 1970-01-01T00:00:01.050Z | |" + - "| a | 1970-01-01T00:00:01.075Z | |" + - "| a | 1970-01-01T00:00:01.100Z | 11 |" + - "| a | 1970-01-01T00:00:01.125Z | |" + - "| b | | 90 |" + - "| b | | 91 |" + - "| b | | 92 |" + - "| b | | 93 |" + - "| c | | |" + - "| c | | |" + - "| c | | 2 |" + - "| c | 1970-01-01T00:00:01.050Z | 20 |" + - "| c | 1970-01-01T00:00:01.075Z | |" + - "| c | 1970-01-01T00:00:01.100Z | 21 |" + - "| c | 1970-01-01T00:00:01.125Z | |" + - +----+--------------------------+----+ + "###); + assert_batch_count(&batches, output_batch_size); + } + }} +} + +#[test] +fn test_gapfill_oom() { + // Show that a graceful error is produced if memory limit is exceeded + test_helpers::maybe_start_logging(); + let input_batch_size = 128; + let output_batch_size = 128; + let batch = TestRecords { + 
group_cols: vec![vec![Some("a"), Some("a")]], + time_col: vec![Some(1_000), Some(1_100)], + timezone: None, + agg_cols: vec![vec![Some(10), Some(11)]], + struct_cols: vec![], + input_batch_size, + }; + let params = get_params_ms(&batch, 25, Some(975), 1_125); + let tc = TestCase { + test_records: batch, + output_batch_size, + params, + }; + let result = tc.run_with_memory_limit(1); + assert_error!(result, DataFusionError::ResourcesExhausted(_)); +} + +#[test] +fn test_gapfill_interpolate_struct() { + test_helpers::maybe_start_logging(); + insta::allow_duplicates! { + for output_batch_size in [16, 1] { + let input_batch_size = 8; + let records = TestRecords { + group_cols: vec![vec![ + Some("a"), + Some("a"), + Some("a"), + // --- new series + Some("b"), + Some("b"), + Some("b"), + Some("b"), + Some("b"), + Some("b"), + ]], + time_col: vec![ + None, + // 975 + Some(1000), + // 1025 + // 1050 + Some(1075), + // 1100 + // 1125 + // --- new series + None, + Some(975), + Some(1000), + Some(1025), + // 1050 + Some(1075), + // 1100 + Some(1125), + ], + timezone: None, + agg_cols: vec![], + struct_cols: vec![vec![ + Some(vec![-1, 0]), + // null, 975 + Some(vec![100, 0]), + // 200 1025 + // 300 1050 + Some(vec![400, 0]), // 1075 + // 1100 + // 1125 + // --- new series + Some(vec![-10, 0]), + Some(vec![1100, 0]), // 975 + None, // 1200 1000 (this null value will be filled) + Some(vec![1300, 0]), // 1025 + // 1325 1050 + Some(vec![1350, 0]), // 1075 + Some(vec![1550, 0]), // 1100 + // 1125 + ]], + input_batch_size, + }; + let params = get_params_ms_with_fill_strategy( + &records, + 25, + Some(975), + 1_125, + FillStrategy::LinearInterpolate + ); + let tc = TestCase { + test_records: records, + output_batch_size, + params, + }; + let batches = tc.run().unwrap(); + let actual = batches_to_lines(&batches); + insta::with_settings!({ + description => format!("input_batch_size: {input_batch_size}, output_batch_size: {output_batch_size}"), + }, { + 
insta::assert_yaml_snapshot!(actual, @r###" + --- + - +----+--------------------------+------------------------+ + - "| g0 | time | a0 |" + - +----+--------------------------+------------------------+ + - "| a | | {value: -1, time: 0} |" + - "| a | 1970-01-01T00:00:00.975Z | {value: , time: } |" + - "| a | 1970-01-01T00:00:01Z | {value: 100, time: 0} |" + - "| a | 1970-01-01T00:00:01.025Z | {value: 200, time: } |" + - "| a | 1970-01-01T00:00:01.050Z | {value: 300, time: } |" + - "| a | 1970-01-01T00:00:01.075Z | {value: 400, time: 0} |" + - "| a | 1970-01-01T00:00:01.100Z | {value: , time: } |" + - "| a | 1970-01-01T00:00:01.125Z | {value: , time: } |" + - "| b | | {value: -10, time: 0} |" + - "| b | 1970-01-01T00:00:00.975Z | {value: 1100, time: 0} |" + - "| b | 1970-01-01T00:00:01Z | {value: 1200, time: } |" + - "| b | 1970-01-01T00:00:01.025Z | {value: 1300, time: 0} |" + - "| b | 1970-01-01T00:00:01.050Z | {value: 1325, time: } |" + - "| b | 1970-01-01T00:00:01.075Z | {value: 1350, time: 0} |" + - "| b | 1970-01-01T00:00:01.100Z | {value: 1450, time: } |" + - "| b | 1970-01-01T00:00:01.125Z | {value: 1550, time: 0} |" + - +----+--------------------------+------------------------+ + "###) + }); + assert_batch_count(&batches, output_batch_size); + } + } +} + +#[test] +fn test_gapfill_interpolate_struct_additional_data() { + test_helpers::maybe_start_logging(); + insta::allow_duplicates! 
{ + for output_batch_size in [16, 1] { + let input_batch_size = 8; + let records = TestRecords { + group_cols: vec![vec![ + Some("a"), + Some("a"), + Some("a"), + // --- new series + Some("b"), + Some("b"), + Some("b"), + Some("b"), + Some("b"), + Some("b"), + ]], + time_col: vec![ + None, + // 975 + Some(1000), + // 1025 + // 1050 + Some(1075), + // 1100 + // 1125 + // --- new series + None, + Some(975), + Some(1000), + Some(1025), + // 1050 + Some(1075), + // 1100 + Some(1125), + ], + timezone: None, + agg_cols: vec![], + struct_cols: vec![vec![ + Some(vec![-1, 0, 1, 1]), + // null, 975 + Some(vec![100, 0, 2, 2]), + // 200 1025 + // 300 1050 + Some(vec![400, 0, 3, 3]), // 1075 + // 1100 + // 1125 + // --- new series + Some(vec![-10, 0, 10, 10]), + Some(vec![1100, 0, 11, 11]), // 975 + None, // 1200 1000 (this null value will be filled) + Some(vec![1300, 0, 12, 12]), // 1025 + // 1325 1050 + Some(vec![1350, 0, 13, 13]), // 1075 + Some(vec![1550, 0, 14, 14]), // 1100 + // 1125 + ]], + input_batch_size, + }; + let params = get_params_ms_with_fill_strategy( + &records, + 25, + Some(975), + 1_125, + FillStrategy::LinearInterpolate + ); + let tc = TestCase { + test_records: records, + output_batch_size, + params, + }; + let batches = tc.run().unwrap(); + let actual = batches_to_lines(&batches); + insta::with_settings!({ + description => format!("input_batch_size: {input_batch_size}, output_batch_size: {output_batch_size}"), + }, { + insta::assert_yaml_snapshot!(actual, @r###" + --- + - +----+--------------------------+--------------------------------------------------+ + - "| g0 | time | a0 |" + - +----+--------------------------+--------------------------------------------------+ + - "| a | | {value: -1, time: 0, other_0: 1, other_1: 1} |" + - "| a | 1970-01-01T00:00:00.975Z | {value: , time: , other_0: , other_1: } |" + - "| a | 1970-01-01T00:00:01Z | {value: 100, time: 0, other_0: 2, other_1: 2} |" + - "| a | 1970-01-01T00:00:01.025Z | {value: 200, time: , other_0: 
, other_1: } |" + - "| a | 1970-01-01T00:00:01.050Z | {value: 300, time: , other_0: , other_1: } |" + - "| a | 1970-01-01T00:00:01.075Z | {value: 400, time: 0, other_0: 3, other_1: 3} |" + - "| a | 1970-01-01T00:00:01.100Z | {value: , time: , other_0: , other_1: } |" + - "| a | 1970-01-01T00:00:01.125Z | {value: , time: , other_0: , other_1: } |" + - "| b | | {value: -10, time: 0, other_0: 10, other_1: 10} |" + - "| b | 1970-01-01T00:00:00.975Z | {value: 1100, time: 0, other_0: 11, other_1: 11} |" + - "| b | 1970-01-01T00:00:01Z | {value: 1200, time: , other_0: , other_1: } |" + - "| b | 1970-01-01T00:00:01.025Z | {value: 1300, time: 0, other_0: 12, other_1: 12} |" + - "| b | 1970-01-01T00:00:01.050Z | {value: 1325, time: , other_0: , other_1: } |" + - "| b | 1970-01-01T00:00:01.075Z | {value: 1350, time: 0, other_0: 13, other_1: 13} |" + - "| b | 1970-01-01T00:00:01.100Z | {value: 1450, time: , other_0: , other_1: } |" + - "| b | 1970-01-01T00:00:01.125Z | {value: 1550, time: 0, other_0: 14, other_1: 14} |" + - +----+--------------------------+--------------------------------------------------+ + "###) + }); + assert_batch_count(&batches, output_batch_size); + } + } +} + +fn assert_batch_count(actual_batches: &[RecordBatch], batch_size: usize) { + let num_rows = actual_batches.iter().map(|b| b.num_rows()).sum::(); + let expected_batch_count = f64::ceil(num_rows as f64 / batch_size as f64) as usize; + assert_eq!(expected_batch_count, actual_batches.len()); +} + +type ExprVec = Vec>; + +pub(super) struct TestRecords { + pub group_cols: Vec>>, + // Stored as millisecods since intervals use millis, + // to let test cases be consistent and easier to read. + pub time_col: Vec>, + pub timezone: Option>, + pub agg_cols: Vec>>, + pub struct_cols: Vec>>>, + pub input_batch_size: usize, +} + +impl TestRecords { + fn schema(&self) -> SchemaRef { + // In order to test input with null timestamps, we need the + // timestamp column to be nullable. 
Unforunately this means + // we can't use the IOx schema builder here. + let mut fields = vec![]; + for i in 0..self.group_cols.len() { + fields.push(Field::new( + format!("g{i}"), + (&InfluxColumnType::Tag).into(), + true, + )); + } + fields.push(Field::new( + "time", + DataType::Timestamp(TimeUnit::Nanosecond, self.timezone.clone()), + true, + )); + for i in 0..self.agg_cols.len() { + fields.push(Field::new( + format!("a{i}"), + (&InfluxColumnType::Field(InfluxFieldType::Integer)).into(), + true, + )); + } + for i in 0..self.struct_cols.len() { + fields.push(Field::new( + format!("a{}", self.agg_cols.len() + i), + DataType::Struct(self.struct_fields(i)), + true, + )); + } + Schema::new(fields).into() + } + + fn struct_fields(&self, col: usize) -> Fields { + let mut fields = vec![ + Field::new( + "value", + (&InfluxColumnType::Field(InfluxFieldType::Integer)).into(), + true, + ), + Field::new( + "time", + (&InfluxColumnType::Field(InfluxFieldType::Integer)).into(), + true, + ), + ]; + let num_other = self.struct_cols[col] + .iter() + .find(|o| o.is_some()) + .map_or(0, |v| match v.as_ref().unwrap().len() { + 0..=2 => 0, + n => n - 2, + }); + for i in 0..num_other { + fields.push(Field::new( + format!("other_{}", i), + (&InfluxColumnType::Field(InfluxFieldType::Integer)).into(), + true, + )); + } + fields.into() + } + + fn len(&self) -> usize { + self.time_col.len() + } + + fn exprs(&self) -> Result<(ExprVec, ExprVec)> { + let mut group_expr: ExprVec = vec![]; + let mut aggr_expr: ExprVec = vec![]; + let ngroup_cols = self.group_cols.len(); + for i in 0..self.schema().fields().len() { + match i.cmp(&ngroup_cols) { + Ordering::Less => group_expr.push(Arc::new(Column::new(&format!("g{i}"), i))), + Ordering::Equal => group_expr.push(Arc::new(Column::new("t", i))), + Ordering::Greater => { + let idx = i - ngroup_cols + 1; + aggr_expr.push(Arc::new(Column::new(&format!("a{idx}"), i))); + } + } + } + Ok((group_expr, aggr_expr)) + } +} + +impl TryFrom for Vec { + type 
Error = DataFusionError; + + fn try_from(value: TestRecords) -> Result { + let mut arrs: Vec = Vec::with_capacity( + value.group_cols.len() + value.agg_cols.len() + value.struct_cols.len() + 1, + ); + for gc in &value.group_cols { + let arr = Arc::new(DictionaryArray::::from_iter(gc.iter().cloned())); + arrs.push(arr); + } + // Scale from milliseconds to the nanoseconds that are actually stored. + let scaled_times = value + .time_col + .iter() + .map(|o| o.map(|v| v * 1_000_000)) + .collect::() + .with_timezone_opt(value.timezone.clone()); + arrs.push(Arc::new(scaled_times)); + for ac in &value.agg_cols { + let arr = Arc::new(Int64Array::from_iter(ac)); + arrs.push(arr); + } + for i in 0..value.struct_cols.len() { + let fields = value.struct_fields(i); + let nulls = value.struct_cols[i] + .iter() + .map(|o| o.is_none()) + .collect::>(); + let mut struct_arrs: Vec = vec![]; + for j in 0..fields.len() { + let arr = Arc::new(Int64Array::from_iter( + value.struct_cols[i] + .iter() + .map(|o| o.as_ref().map(|v| v[j])), + )); + struct_arrs.push(arr); + } + arrs.push(Arc::new(StructArray::new( + fields, + struct_arrs, + Some(nulls.into()), + ))); + } + + let one_batch = RecordBatch::try_new(value.schema(), arrs) + .map_err(|err| DataFusionError::ArrowError(err, None))?; + let mut batches = vec![]; + let mut offset = 0; + while offset < one_batch.num_rows() { + let len = std::cmp::min(value.input_batch_size, one_batch.num_rows() - offset); + let batch = one_batch.slice(offset, len); + batches.push(batch); + offset += value.input_batch_size; + } + Ok(batches) + } +} + +struct TestCase { + test_records: TestRecords, + output_batch_size: usize, + params: GapFillExecParams, +} + +impl TestCase { + fn run(self) -> Result> { + block_on(async { + let session_ctx = SessionContext::new_with_config( + SessionConfig::default().with_batch_size(self.output_batch_size), + ) + .into(); + Self::execute_with_config(&session_ctx, self.plan()?).await + }) + } + + fn 
run_with_memory_limit(self, limit: usize) -> Result> { + block_on(async { + let session_ctx = SessionContext::new_with_config_rt( + SessionConfig::default().with_batch_size(self.output_batch_size), + RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(limit, 1.0))?.into(), + ) + .into(); + let result = Self::execute_with_config(&session_ctx, self.plan()?).await; + + if result.is_ok() { + // Verify that the operator reports usage in a + // symmetrical way. + let pool = &session_ctx.runtime_env().memory_pool; + assert_eq!(0, pool.reserved()); + } + + result + }) + } + + fn plan(self) -> Result> { + let schema = self.test_records.schema(); + let (group_expr, aggr_expr) = self.test_records.exprs()?; + + let input_batch_size = self.test_records.input_batch_size; + + let num_records = self.test_records.len(); + let batches: Vec = self.test_records.try_into()?; + assert_batch_count(&batches, input_batch_size); + assert_eq!( + batches.iter().map(|b| b.num_rows()).sum::(), + num_records + ); + + debug!( + "input_batch_size is {input_batch_size}, output_batch_size is {}", + self.output_batch_size + ); + let input = Arc::new(MemoryExec::try_new(&[batches], schema, None)?); + let plan = Arc::new(GapFillExec::try_new( + input, + group_expr, + aggr_expr, + self.params.clone(), + )?); + Ok(plan) + } + + async fn execute_with_config( + session_ctx: &Arc, + plan: Arc, + ) -> Result> { + let task_ctx = Arc::new(TaskContext::from(session_ctx.as_ref())); + collect(plan, task_ctx).await + } +} + +fn bound_included_from_option(o: Option) -> Bound { + if let Some(v) = o { + Bound::Included(v) + } else { + Bound::Unbounded + } +} + +fn phys_fill_strategies( + records: &TestRecords, + fill_strategy: FillStrategy, +) -> Result, FillStrategy)>> { + let start = records.group_cols.len() + 1; // 1 is for time col + let end = start + records.agg_cols.len() + records.struct_cols.len(); + let mut v = Vec::with_capacity(records.agg_cols.len()); + for f in 
&records.schema().fields()[start..end] { + v.push((phys_col(f.name(), &records.schema())?, fill_strategy)); + } + Ok(v) +} + +fn get_params_ms_with_fill_strategy( + batch: &TestRecords, + stride_ms: i64, + start: Option, + end: i64, + fill_strategy: FillStrategy, +) -> GapFillExecParams { + get_params_ms_with_origin_fill_strategy(batch, stride_ms, start, end, None, fill_strategy) +} + +fn get_params_ms_with_origin_fill_strategy( + batch: &TestRecords, + stride_ms: i64, + start: Option, + end: i64, + origin_ms: Option, + fill_strategy: FillStrategy, +) -> GapFillExecParams { + // stride is in ms + let stride = ScalarValue::new_interval_mdn(0, 0, stride_ms * 1_000_000); + let origin = + origin_ms.map(|o| phys_lit(ScalarValue::TimestampNanosecond(Some(o * 1_000_000), None))); + + GapFillExecParams { + stride: phys_lit(stride), + time_column: Column::new("t", batch.group_cols.len()), + origin, + // timestamps are nanos, so scale them accordingly + time_range: Range { + start: bound_included_from_option(start.map(|start| { + phys_lit(ScalarValue::TimestampNanosecond( + Some(start * 1_000_000), + None, + )) + })), + end: Bound::Included(phys_lit(ScalarValue::TimestampNanosecond( + Some(end * 1_000_000), + None, + ))), + }, + fill_strategy: phys_fill_strategies(batch, fill_strategy).unwrap(), + } +} + +fn get_params_ms( + batch: &TestRecords, + stride: i64, + start: Option, + end: i64, +) -> GapFillExecParams { + get_params_ms_with_fill_strategy(batch, stride, start, end, FillStrategy::Null) +} diff --git a/iox_query/src/exec/gapfill/mod.rs b/iox_query/src/exec/gapfill/mod.rs new file mode 100644 index 0000000..30ef8a5 --- /dev/null +++ b/iox_query/src/exec/gapfill/mod.rs @@ -0,0 +1,823 @@ +//! This module contains code that implements +//! 
a gap-filling extension to DataFusion + +mod algo; +mod buffered_input; +#[cfg(test)] +mod exec_tests; +mod params; +mod stream; + +use std::{ + fmt::{self, Debug}, + ops::{Bound, Range}, + sync::Arc, +}; + +use arrow::{compute::SortOptions, datatypes::SchemaRef}; +use datafusion::{ + common::DFSchemaRef, + error::{DataFusionError, Result}, + execution::{context::TaskContext, memory_pool::MemoryConsumer}, + logical_expr::{LogicalPlan, UserDefinedLogicalNodeCore}, + physical_expr::{ + create_physical_expr, execution_props::ExecutionProps, PhysicalSortExpr, + PhysicalSortRequirement, + }, + physical_plan::{ + expressions::Column, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet}, + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, + SendableRecordBatchStream, Statistics, + }, + prelude::Expr, +}; + +use self::stream::GapFillStream; + +/// A logical node that represents the gap filling operation. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct GapFill { + /// The incoming logical plan + pub input: Arc, + /// Grouping expressions + pub group_expr: Vec, + /// Aggregate expressions + pub aggr_expr: Vec, + /// Parameters to configure the behavior of the + /// gap-filling operation + pub params: GapFillParams, +} + +/// Parameters to the GapFill operation +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct GapFillParams { + /// The stride argument from the call to DATE_BIN_GAPFILL + pub stride: Expr, + /// The source time column + pub time_column: Expr, + /// The origin argument from the call to DATE_BIN_GAPFILL + pub origin: Option, + /// The time range of the time column inferred from predicates + /// in the overall query. The lower bound may be [`Bound::Unbounded`] + /// which implies that gap-filling should just start from the + /// first point in each series. + pub time_range: Range>, + /// What to do when filling aggregate columns. 
+ /// The first item in the tuple will be the column + /// reference for the aggregate column. + pub fill_strategy: Vec<(Expr, FillStrategy)>, +} + +/// Describes how to fill gaps in an aggregate column. +#[derive(Clone, Debug, Hash, PartialEq, Eq, Copy)] +pub enum FillStrategy { + /// Fill with null values. + /// This is the InfluxQL behavior for `FILL(NULL)` or `FILL(NONE)`. + Null, + /// Fill with the most recent value in the input column. + /// Null values in the input are preserved. + #[allow(dead_code)] + PrevNullAsIntentional, + /// Fill with the most recent non-null value in the input column. + /// This is the InfluxQL behavior for `FILL(PREVIOUS)`. + PrevNullAsMissing, + /// Fill the gaps between points linearly. + /// Null values will not be considered as missing, so two non-null values + /// with a null in between will not be filled. + LinearInterpolate, +} + +impl GapFillParams { + // Extract the expressions so they can be optimized. + fn expressions(&self) -> Vec { + let mut exprs = vec![self.stride.clone(), self.time_column.clone()]; + if let Some(e) = self.origin.as_ref() { + exprs.push(e.clone()) + } + if let Some(start) = bound_extract(&self.time_range.start) { + exprs.push(start.clone()); + } + exprs.push( + bound_extract(&self.time_range.end) + .unwrap_or_else(|| panic!("upper time bound is required")) + .clone(), + ); + exprs + } + + #[allow(clippy::wrong_self_convention)] // follows convention of UserDefinedLogicalNode + fn from_template(&self, exprs: &[Expr], aggr_expr: &[Expr]) -> Self { + assert!( + exprs.len() >= 3, + "should be a at least stride, source and origin in params" + ); + let mut iter = exprs.iter().cloned(); + let stride = iter.next().unwrap(); + let time_column = iter.next().unwrap(); + let origin = self.origin.as_ref().map(|_| iter.next().unwrap()); + let time_range = try_map_range(&self.time_range, |b| { + try_map_bound(b.as_ref(), |_| { + Ok(iter.next().expect("expr count should match template")) + }) + }) + .unwrap(); + + 
let fill_strategy = aggr_expr + .iter() + .cloned() + .zip( + self.fill_strategy + .iter() + .map(|(_expr, fill_strategy)| fill_strategy) + .cloned(), + ) + .collect(); + + Self { + stride, + time_column, + origin, + time_range, + fill_strategy, + } + } + + // Find the expression that matches `e` and replace its fill strategy. + // If such an expression is found, return the old strategy, and `None` otherwise. + fn replace_fill_strategy(&mut self, e: &Expr, mut fs: FillStrategy) -> Option { + for expr_fs in &mut self.fill_strategy { + if &expr_fs.0 == e { + std::mem::swap(&mut fs, &mut expr_fs.1); + return Some(fs); + } + } + None + } +} + +impl GapFill { + /// Create a new gap-filling operator. + pub fn try_new( + input: Arc, + group_expr: Vec, + aggr_expr: Vec, + params: GapFillParams, + ) -> Result { + if params.time_range.end == Bound::Unbounded { + return Err(DataFusionError::Internal( + "missing upper bound in GapFill time range".to_string(), + )); + } + Ok(Self { + input, + group_expr, + aggr_expr, + params, + }) + } + + // Find the expression that matches `e` and replace its fill strategy. + // If such an expression is found, return the old strategy, and `None` otherwise. 
+ pub(crate) fn replace_fill_strategy( + &mut self, + e: &Expr, + fs: FillStrategy, + ) -> Option { + self.params.replace_fill_strategy(e, fs) + } +} + +impl UserDefinedLogicalNodeCore for GapFill { + fn name(&self) -> &str { + "GapFill" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![self.input.as_ref()] + } + + fn schema(&self) -> &DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec { + self.group_expr + .iter() + .chain(&self.aggr_expr) + .chain(&self.params.expressions()) + .cloned() + .collect() + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let aggr_expr: String = self + .params + .fill_strategy + .iter() + .map(|(e, fs)| match fs { + FillStrategy::PrevNullAsIntentional => format!("LOCF(null-as-intentional, {})", e), + FillStrategy::PrevNullAsMissing => format!("LOCF({})", e), + FillStrategy::LinearInterpolate => format!("INTERPOLATE({})", e), + FillStrategy::Null => e.to_string(), + }) + .collect::>() + .join(", "); + + let group_expr = self + .group_expr + .iter() + .map(|e| e.to_string()) + .collect::>() + .join(", "); + + write!( + f, + "{}: groupBy=[{group_expr}], aggr=[[{aggr_expr}]], time_column={}, stride={}, range={:?}", + self.name(), + self.params.time_column, + self.params.stride, + self.params.time_range, + ) + } + + fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + let mut group_expr: Vec<_> = exprs.to_vec(); + let mut aggr_expr = group_expr.split_off(self.group_expr.len()); + let param_expr = aggr_expr.split_off(self.aggr_expr.len()); + let params = self.params.from_template(¶m_expr, &aggr_expr); + Self::try_new(Arc::new(inputs[0].clone()), group_expr, aggr_expr, params) + .expect("should not fail") + } +} + +/// Called by the extension planner to plan a [GapFill] node. 
+pub(crate) fn plan_gap_fill( + execution_props: &ExecutionProps, + gap_fill: &GapFill, + logical_inputs: &[&LogicalPlan], + physical_inputs: &[Arc], +) -> Result { + if logical_inputs.len() != 1 { + return Err(DataFusionError::Internal( + "GapFillExec: wrong number of logical inputs".to_string(), + )); + } + if physical_inputs.len() != 1 { + return Err(DataFusionError::Internal( + "GapFillExec: wrong number of physical inputs".to_string(), + )); + } + + let input_dfschema = logical_inputs[0].schema().as_ref(); + let input_schema = physical_inputs[0].schema(); + let input_schema = input_schema.as_ref(); + + let group_expr: Result> = gap_fill + .group_expr + .iter() + .map(|e| create_physical_expr(e, input_dfschema, input_schema, execution_props)) + .collect(); + let group_expr = group_expr?; + + let aggr_expr: Result> = gap_fill + .aggr_expr + .iter() + .map(|e| create_physical_expr(e, input_dfschema, input_schema, execution_props)) + .collect(); + let aggr_expr = aggr_expr?; + + let logical_time_column = gap_fill.params.time_column.try_into_col()?; + let time_column = Column::new_with_schema(&logical_time_column.name, input_schema)?; + + let stride = create_physical_expr( + &gap_fill.params.stride, + input_dfschema, + input_schema, + execution_props, + )?; + + let time_range = &gap_fill.params.time_range; + let time_range = try_map_range(time_range, |b| { + try_map_bound(b.as_ref(), |e| { + create_physical_expr(e, input_dfschema, input_schema, execution_props) + }) + })?; + + let origin = gap_fill + .params + .origin + .as_ref() + .map(|e| create_physical_expr(e, input_dfschema, input_schema, execution_props)) + .transpose()?; + + let fill_strategy = gap_fill + .params + .fill_strategy + .iter() + .map(|(e, fs)| { + Ok(( + create_physical_expr(e, input_dfschema, input_schema, execution_props)?, + *fs, + )) + }) + .collect::, FillStrategy)>>>()?; + + let params = GapFillExecParams { + stride, + time_column, + origin, + time_range, + fill_strategy, + }; + 
GapFillExec::try_new( + Arc::clone(&physical_inputs[0]), + group_expr, + aggr_expr, + params, + ) +} + +fn try_map_range(tr: &Range, mut f: F) -> Result> +where + F: FnMut(&T) -> Result, +{ + Ok(Range { + start: f(&tr.start)?, + end: f(&tr.end)?, + }) +} + +fn try_map_bound(bt: Bound, mut f: F) -> Result> +where + F: FnMut(T) -> Result, +{ + Ok(match bt { + Bound::Excluded(t) => Bound::Excluded(f(t)?), + Bound::Included(t) => Bound::Included(f(t)?), + Bound::Unbounded => Bound::Unbounded, + }) +} + +fn bound_extract(b: &Bound) -> Option<&T> { + match b { + Bound::Included(t) | Bound::Excluded(t) => Some(t), + Bound::Unbounded => None, + } +} + +/// A physical node for the gap-fill operation. +pub struct GapFillExec { + input: Arc, + // The group by expressions from the original aggregation node. + group_expr: Vec>, + // The aggregate expressions from the original aggregation node. + aggr_expr: Vec>, + // The sort expressions for the required sort order of the input: + // all of the group exressions, with the time column being last. + sort_expr: Vec, + // Parameters (besides streaming data) to gap filling + params: GapFillExecParams, + /// Metrics reporting behavior during execution. + metrics: ExecutionPlanMetricsSet, +} + +#[derive(Clone, Debug)] +struct GapFillExecParams { + /// The uniform interval of incoming timestamps + stride: Arc, + /// The timestamp column produced by date_bin + time_column: Column, + /// The origin argument from the all to DATE_BIN_GAPFILL + origin: Option>, + /// The time range of source input to DATE_BIN_GAPFILL. + /// Inferred from predicates in the overall query. + time_range: Range>>, + /// What to do when filling aggregate columns. + /// The 0th element in each tuple is the aggregate column. 
+ fill_strategy: Vec<(Arc, FillStrategy)>, +} + +impl GapFillExec { + fn try_new( + input: Arc, + group_expr: Vec>, + aggr_expr: Vec>, + params: GapFillExecParams, + ) -> Result { + let sort_expr = { + let mut sort_expr: Vec<_> = group_expr + .iter() + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: SortOptions::default(), + }) + .collect(); + + // Ensure that the time column is the last component in the sort + // expressions. + let time_idx = group_expr + .iter() + .enumerate() + .find(|(_i, e)| { + e.as_any() + .downcast_ref::() + .map_or(false, |c| c.index() == params.time_column.index()) + }) + .map(|(i, _)| i); + + if let Some(time_idx) = time_idx { + let last_elem = sort_expr.len() - 1; + sort_expr.swap(time_idx, last_elem); + } else { + return Err(DataFusionError::Internal( + "could not find time column for GapFillExec".to_string(), + )); + } + + sort_expr + }; + + Ok(Self { + input, + group_expr, + aggr_expr, + sort_expr, + params, + metrics: ExecutionPlanMetricsSet::new(), + }) + } +} + +impl Debug for GapFillExec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "GapFillExec") + } +} + +impl ExecutionPlan for GapFillExec { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.input.schema() + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(1) + } + + fn required_input_distribution(&self) -> Vec { + // It seems like it could be possible to partition on all the + // group keys except for the time expression. For now, keep it simple. 
+ vec![Distribution::SinglePartition] + } + + fn output_ordering(&self) -> Option<&[datafusion::physical_expr::PhysicalSortExpr]> { + self.input.output_ordering() + } + + fn required_input_ordering(&self) -> Vec>> { + vec![Some(PhysicalSortRequirement::from_sort_exprs( + &self.sort_expr, + ))] + } + + fn maintains_input_order(&self) -> Vec { + vec![true] + } + + fn children(&self) -> Vec> { + vec![Arc::clone(&self.input)] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + match children.len() { + 1 => Ok(Arc::new(Self::try_new( + Arc::clone(&children[0]), + self.group_expr.clone(), + self.aggr_expr.clone(), + self.params.clone(), + )?)), + _ => Err(DataFusionError::Internal( + "GapFillExec wrong number of children".to_string(), + )), + } + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + if partition != 0 { + return Err(DataFusionError::Internal(format!( + "GapFillExec invalid partition {partition}, there can be only one partition" + ))); + } + + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + let output_batch_size = context.session_config().batch_size(); + let reservation = MemoryConsumer::new(format!("GapFillExec[{partition}]")) + .register(context.memory_pool()); + let input_stream = self.input.execute(partition, context)?; + Ok(Box::pin(GapFillStream::try_new( + self, + output_batch_size, + input_stream, + reservation, + baseline_metrics, + )?)) + } + + fn statistics(&self) -> Result { + Ok(Statistics::new_unknown(&self.schema())) + } +} + +impl DisplayAs for GapFillExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let group_expr: Vec<_> = self.group_expr.iter().map(|e| e.to_string()).collect(); + let aggr_expr: Vec<_> = self + .params + .fill_strategy + .iter() + .map(|(e, fs)| match fs { + FillStrategy::PrevNullAsIntentional => { + 
format!("LOCF(null-as-intentional, {})", e) + } + FillStrategy::PrevNullAsMissing => format!("LOCF({})", e), + FillStrategy::LinearInterpolate => format!("INTERPOLATE({})", e), + FillStrategy::Null => e.to_string(), + }) + .collect(); + let time_range = try_map_range(&self.params.time_range, |b| { + try_map_bound(b.as_ref(), |e| Ok(e.to_string())) + }) + .map_err(|_| fmt::Error {})?; + write!( + f, + "GapFillExec: group_expr=[{}], aggr_expr=[{}], stride={}, time_range={:?}", + group_expr.join(", "), + aggr_expr.join(", "), + self.params.stride, + time_range + ) + } + } + } +} + +#[cfg(test)] +mod test { + use std::ops::{Bound, Range}; + + use crate::{ + exec::{Executor, ExecutorType}, + test::{format_execution_plan, format_logical_plan}, + }; + + use super::*; + use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; + use datafusion::{ + datasource::empty::EmptyTable, + error::Result, + logical_expr::{logical_plan, Extension, UserDefinedLogicalNode}, + prelude::{col, lit}, + scalar::ScalarValue, + }; + use datafusion_util::lit_timestamptz_nano; + + use test_helpers::assert_error; + + fn schema() -> Schema { + Schema::new(vec![ + Field::new( + "time", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new("loc", DataType::Utf8, false), + Field::new("temp", DataType::Float64, false), + ]) + } + + fn table_scan() -> Result { + let schema = schema(); + logical_plan::table_scan(Some("temps"), &schema, None)?.build() + } + + fn fill_strategy_null(cols: Vec) -> Vec<(Expr, FillStrategy)> { + cols.into_iter().map(|e| (e, FillStrategy::Null)).collect() + } + + #[test] + fn test_try_new_errs() { + let scan = table_scan().unwrap(); + let result = GapFill::try_new( + Arc::new(scan), + vec![col("loc"), col("time")], + vec![col("temp")], + GapFillParams { + stride: lit(ScalarValue::IntervalDayTime(Some(60_000))), + time_column: col("time"), + origin: None, + time_range: Range { + start: Bound::Included(lit_timestamptz_nano(1000)), + end: 
Bound::Unbounded, + }, + fill_strategy: fill_strategy_null(vec![col("temp")]), + }, + ); + + assert_error!(result, DataFusionError::Internal(ref msg) if msg == "missing upper bound in GapFill time range"); + } + + fn assert_gapfill_from_template_roundtrip(gapfill: &GapFill) { + let gapfill_as_node: &dyn UserDefinedLogicalNode = gapfill; + let scan = table_scan().unwrap(); + let exprs = gapfill_as_node.expressions(); + let want_exprs = gapfill.group_expr.len() + + gapfill.aggr_expr.len() + + 2 // stride, time + + gapfill.params.origin.iter().count() + + bound_extract(&gapfill.params.time_range.start).iter().count() + + bound_extract(&gapfill.params.time_range.end).iter().count(); + assert_eq!(want_exprs, exprs.len()); + let gapfill_ft = gapfill_as_node.from_template(&exprs, &[scan]); + let gapfill_ft = gapfill_ft + .as_any() + .downcast_ref::() + .expect("should be a GapFill"); + assert_eq!(gapfill.group_expr, gapfill_ft.group_expr); + assert_eq!(gapfill.aggr_expr, gapfill_ft.aggr_expr); + assert_eq!(gapfill.params, gapfill_ft.params); + } + + #[test] + fn test_from_template() { + for params in vec![ + // no origin, no start bound + GapFillParams { + stride: lit(ScalarValue::IntervalDayTime(Some(60_000))), + time_column: col("time"), + origin: None, + time_range: Range { + start: Bound::Unbounded, + end: Bound::Excluded(lit_timestamptz_nano(2000)), + }, + fill_strategy: fill_strategy_null(vec![col("temp")]), + }, + // no origin, yes start bound + GapFillParams { + stride: lit(ScalarValue::IntervalDayTime(Some(60_000))), + time_column: col("time"), + origin: None, + time_range: Range { + start: Bound::Included(lit_timestamptz_nano(1000)), + end: Bound::Excluded(lit_timestamptz_nano(2000)), + }, + fill_strategy: fill_strategy_null(vec![col("temp")]), + }, + // yes origin, no start bound + GapFillParams { + stride: lit(ScalarValue::IntervalDayTime(Some(60_000))), + time_column: col("time"), + origin: Some(lit_timestamptz_nano(1_000_000_000)), + time_range: Range { + 
start: Bound::Unbounded, + end: Bound::Excluded(lit_timestamptz_nano(2000)), + }, + fill_strategy: fill_strategy_null(vec![col("temp")]), + }, + // yes origin, yes start bound + GapFillParams { + stride: lit(ScalarValue::IntervalDayTime(Some(60_000))), + time_column: col("time"), + origin: Some(lit_timestamptz_nano(1_000_000_000)), + time_range: Range { + start: Bound::Included(lit_timestamptz_nano(1000)), + end: Bound::Excluded(lit_timestamptz_nano(2000)), + }, + fill_strategy: fill_strategy_null(vec![col("temp")]), + }, + ] { + let scan = table_scan().unwrap(); + let gapfill = GapFill::try_new( + Arc::new(scan.clone()), + vec![col("loc"), col("time")], + vec![col("temp")], + params, + ) + .unwrap(); + assert_gapfill_from_template_roundtrip(&gapfill); + } + } + + #[test] + fn fmt_logical_plan() -> Result<()> { + // This test case does not make much sense but + // just verifies we can construct a logical gapfill node + // and show its plan. + let scan = table_scan()?; + let gapfill = GapFill::try_new( + Arc::new(scan), + vec![col("loc"), col("time")], + vec![col("temp")], + GapFillParams { + stride: lit(ScalarValue::IntervalDayTime(Some(60_000))), + time_column: col("time"), + origin: None, + time_range: Range { + start: Bound::Included(lit_timestamptz_nano(1000)), + end: Bound::Excluded(lit_timestamptz_nano(2000)), + }, + fill_strategy: fill_strategy_null(vec![col("temp")]), + }, + )?; + let plan = LogicalPlan::Extension(Extension { + node: Arc::new(gapfill), + }); + + insta::assert_yaml_snapshot!( + format_logical_plan(&plan), + @r###" + --- + - " GapFill: groupBy=[loc, time], aggr=[[temp]], time_column=time, stride=IntervalDayTime(\"60000\"), range=Included(Literal(TimestampNanosecond(1000, None)))..Excluded(Literal(TimestampNanosecond(2000, None)))" + - " TableScan: temps" + "### + ); + Ok(()) + } + + async fn format_explain(sql: &str) -> Result> { + let executor = Executor::new_testing(); + let context = executor.new_context(ExecutorType::Query); + context + 
.inner() + .register_table("temps", Arc::new(EmptyTable::new(Arc::new(schema()))))?; + let physical_plan = context.sql_to_physical_plan(sql).await?; + Ok(format_execution_plan(&physical_plan)) + } + + #[tokio::test] + async fn plan_gap_fill() -> Result<()> { + // show that the optimizer rule can fire and that physical + // planning will succeed. + let sql = "SELECT date_bin_gapfill(interval '1 minute', time, timestamp '1970-01-01T00:00:00Z') AS minute, avg(temp)\ + \nFROM temps\ + \nWHERE time >= '1980-01-01T00:00:00Z' and time < '1981-01-01T00:00:00Z'\ + \nGROUP BY minute;"; + + let explain = format_explain(sql).await?; + insta::assert_yaml_snapshot!( + explain, + @r###" + --- + - " ProjectionExec: expr=[date_bin_gapfill(IntervalMonthDayNano(\"60000000000\"),temps.time,Utf8(\"1970-01-01T00:00:00Z\"))@0 as minute, AVG(temps.temp)@1 as AVG(temps.temp)]" + - " GapFillExec: group_expr=[date_bin_gapfill(IntervalMonthDayNano(\"60000000000\"),temps.time,Utf8(\"1970-01-01T00:00:00Z\"))@0], aggr_expr=[AVG(temps.temp)@1], stride=60000000000, time_range=Included(\"315532800000000000\")..Excluded(\"347155200000000000\")" + - " SortExec: expr=[date_bin_gapfill(IntervalMonthDayNano(\"60000000000\"),temps.time,Utf8(\"1970-01-01T00:00:00Z\"))@0 ASC]" + - " AggregateExec: mode=Final, gby=[date_bin_gapfill(IntervalMonthDayNano(\"60000000000\"),temps.time,Utf8(\"1970-01-01T00:00:00Z\"))@0 as date_bin_gapfill(IntervalMonthDayNano(\"60000000000\"),temps.time,Utf8(\"1970-01-01T00:00:00Z\"))], aggr=[AVG(temps.temp)]" + - " AggregateExec: mode=Partial, gby=[date_bin(60000000000, time@0, 0) as date_bin_gapfill(IntervalMonthDayNano(\"60000000000\"),temps.time,Utf8(\"1970-01-01T00:00:00Z\"))], aggr=[AVG(temps.temp)]" + - " EmptyExec" + "### + ); + Ok(()) + } + + #[tokio::test] + async fn gap_fill_exec_sort_order() -> Result<()> { + // The call to `date_bin_gapfill` should be last in the SortExec + // expressions, even though it was not last on the SELECT list + // or the GROUP BY clause. 
+ let sql = "SELECT \ + \n loc,\ + \n date_bin_gapfill(interval '1 minute', time, timestamp '1970-01-01T00:00:00Z') AS minute,\ + \n concat('zz', loc) AS loczz,\ + \n avg(temp)\ + \nFROM temps\ + \nWHERE time >= '1980-01-01T00:00:00Z' and time < '1981-01-01T00:00:00Z' + \nGROUP BY loc, minute, loczz;"; + + let explain = format_explain(sql).await?; + insta::assert_yaml_snapshot!( + explain, + @r###" + --- + - " ProjectionExec: expr=[loc@0 as loc, date_bin_gapfill(IntervalMonthDayNano(\"60000000000\"),temps.time,Utf8(\"1970-01-01T00:00:00Z\"))@1 as minute, concat(Utf8(\"zz\"),temps.loc)@2 as loczz, AVG(temps.temp)@3 as AVG(temps.temp)]" + - " GapFillExec: group_expr=[loc@0, date_bin_gapfill(IntervalMonthDayNano(\"60000000000\"),temps.time,Utf8(\"1970-01-01T00:00:00Z\"))@1, concat(Utf8(\"zz\"),temps.loc)@2], aggr_expr=[AVG(temps.temp)@3], stride=60000000000, time_range=Included(\"315532800000000000\")..Excluded(\"347155200000000000\")" + - " SortExec: expr=[loc@0 ASC,concat(Utf8(\"zz\"),temps.loc)@2 ASC,date_bin_gapfill(IntervalMonthDayNano(\"60000000000\"),temps.time,Utf8(\"1970-01-01T00:00:00Z\"))@1 ASC]" + - " AggregateExec: mode=Final, gby=[loc@0 as loc, date_bin_gapfill(IntervalMonthDayNano(\"60000000000\"),temps.time,Utf8(\"1970-01-01T00:00:00Z\"))@1 as date_bin_gapfill(IntervalMonthDayNano(\"60000000000\"),temps.time,Utf8(\"1970-01-01T00:00:00Z\")), concat(Utf8(\"zz\"),temps.loc)@2 as concat(Utf8(\"zz\"),temps.loc)], aggr=[AVG(temps.temp)]" + - " AggregateExec: mode=Partial, gby=[loc@1 as loc, date_bin(60000000000, time@0, 0) as date_bin_gapfill(IntervalMonthDayNano(\"60000000000\"),temps.time,Utf8(\"1970-01-01T00:00:00Z\")), concat(zz, loc@1) as concat(Utf8(\"zz\"),temps.loc)], aggr=[AVG(temps.temp)]" + - " EmptyExec" + "### + ); + Ok(()) + } +} diff --git a/iox_query/src/exec/gapfill/params.rs b/iox_query/src/exec/gapfill/params.rs new file mode 100644 index 0000000..5e9d0c4 --- /dev/null +++ b/iox_query/src/exec/gapfill/params.rs @@ -0,0 +1,392 @@ +//! 
Evaluate the parameters to be used for gap filling. +use std::ops::Bound; + +use arrow::{ + datatypes::{IntervalMonthDayNanoType, SchemaRef}, + record_batch::RecordBatch, +}; +use chrono::Duration; +use datafusion::{ + error::{DataFusionError, Result}, + physical_expr::datetime_expressions::date_bin, + physical_plan::{expressions::Column, ColumnarValue}, + scalar::ScalarValue, +}; +use hashbrown::HashMap; + +use super::{try_map_bound, try_map_range, FillStrategy, GapFillExecParams}; + +/// The parameters to gap filling. Included here are the parameters +/// that remain constant during gap filling, i.e., not the streaming table +/// data, or anything else. +/// When we support `locf` for aggregate columns, that will be tracked here. +#[derive(Clone, Debug, PartialEq)] +pub(crate) struct GapFillParams { + /// The stride in nanoseconds of the timestamps to be output. + pub stride: i64, + /// The first timestamp (inclusive) to be output for each series, + /// in nanoseconds since the epoch. `None` means gap filling should + /// start from the first timestamp in each series. + pub first_ts: Option, + /// The last timestamp (inclusive!) to be output for each series, + /// in nanoseconds since the epoch. + pub last_ts: i64, + /// What to do when filling gaps in aggregate columns. + /// The map is keyed on the columns offset in the schema. + pub fill_strategy: HashMap, +} + +impl GapFillParams { + /// Create a new [GapFillParams] by figuring out the actual values (as native i64) for the stride, + /// first and last timestamp for gap filling. 
+ pub(super) fn try_new(schema: SchemaRef, params: &GapFillExecParams) -> Result { + let batch = RecordBatch::new_empty(schema); + let stride = params.stride.evaluate(&batch)?; + let origin = params + .origin + .as_ref() + .map(|e| e.evaluate(&batch)) + .transpose()?; + + // Evaluate the upper and lower bounds of the time range + let range = try_map_range(¶ms.time_range, |b| { + try_map_bound(b.as_ref(), |pe| { + extract_timestamp_nanos(&pe.evaluate(&batch)?) + }) + })?; + + // Find the smallest timestamp that might appear in the + // range. There might not be one, which is okay. + let first_ts = match range.start { + Bound::Included(v) => Some(v), + Bound::Excluded(v) => Some(v + 1), + Bound::Unbounded => None, + }; + + // Find the largest timestamp that might appear in the + // range + let last_ts = match range.end { + Bound::Included(v) => v, + Bound::Excluded(v) => v - 1, + Bound::Unbounded => { + return Err(DataFusionError::Execution( + "missing upper time bound for gap filling".to_string(), + )) + } + }; + + // Call date_bin on the timestamps to find the first and last time bins + // for each series + let mut args = vec![stride, i64_to_columnar_ts(first_ts)]; + if let Some(v) = origin { + args.push(v) + } + let first_ts = first_ts + .map(|_| extract_timestamp_nanos(&date_bin(&args)?)) + .transpose()?; + args[1] = i64_to_columnar_ts(Some(last_ts)); + let last_ts = extract_timestamp_nanos(&date_bin(&args)?)?; + + let fill_strategy = params + .fill_strategy + .iter() + .map(|(e, fs)| { + let idx = e + .as_any() + .downcast_ref::() + .ok_or(DataFusionError::Internal(format!( + "fill strategy aggr expr was not a column: {e:?}", + )))? + .index(); + Ok((idx, *fs)) + }) + .collect::>>()?; + + Ok(Self { + stride: extract_interval_nanos(&args[0])?, + first_ts, + last_ts, + fill_strategy, + }) + } + + /// Returns the number of rows remaining for a series that starts with first_ts. 
+ pub fn valid_row_count(&self, first_ts: i64) -> usize { + if self.last_ts >= first_ts { + ((self.last_ts - first_ts) / self.stride + 1) as usize + } else { + 0 + } + } +} + +fn i64_to_columnar_ts(i: Option) -> ColumnarValue { + match i { + Some(i) => ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(i), None)), + None => ColumnarValue::Scalar(ScalarValue::Null), + } +} + +fn extract_timestamp_nanos(cv: &ColumnarValue) -> Result { + Ok(match cv { + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(v), _)) => *v, + _ => { + return Err(DataFusionError::Execution( + "gap filling argument must be a scalar timestamp".to_string(), + )) + } + }) +} + +fn extract_interval_nanos(cv: &ColumnarValue) -> Result { + match cv { + ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some(v))) => { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v); + + if months != 0 { + return Err(DataFusionError::Execution( + "gap filling does not support month intervals".to_string(), + )); + } + + let nanos = + (Duration::days(days as i64) + Duration::nanoseconds(nanos)).num_nanoseconds(); + nanos.ok_or_else(|| { + DataFusionError::Execution("gap filling argument is too large".to_string()) + }) + } + _ => Err(DataFusionError::Execution( + "gap filling expects a stride parameter to be a scalar interval".to_string(), + )), + } +} + +#[cfg(test)] +mod tests { + use std::{ + ops::{Bound, Range}, + sync::Arc, + }; + + use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; + use datafusion::{ + datasource::empty::EmptyTable, + error::Result, + physical_plan::{ + expressions::{Column, Literal}, + PhysicalExpr, + }, + scalar::ScalarValue, + }; + use hashbrown::HashMap; + + use crate::exec::{ + gapfill::{FillStrategy, GapFillExec, GapFillExecParams}, + Executor, ExecutorType, + }; + + use super::GapFillParams; + + #[tokio::test] + async fn test_evaluate_params() -> Result<()> { + test_helpers::maybe_start_logging(); + let actual = 
plan_statement_and_get_params( + "select\ + \n date_bin_gapfill(interval '1 minute', time) minute\ + \nfrom t\ + \nwhere time >= timestamp '1984-01-01T16:00:00Z' - interval '5 minutes'\ + \n and time <= timestamp '1984-01-01T16:00:00Z'\ + \ngroup by minute", + ) + .await?; + let expected = GapFillParams { + stride: 60_000_000_000, // 1 minute + first_ts: Some(441_820_500_000_000_000), // Sunday, January 1, 1984 3:55:00 PM + last_ts: 441_820_800_000_000_000, // Sunday, January 1, 1984 3:59:00 PM + fill_strategy: HashMap::new(), + }; + assert_eq!(expected, actual); + Ok(()) + } + + #[tokio::test] + async fn test_evaluate_params_default_origin() -> Result<()> { + // as above but the default origin is explicitly specified. + test_helpers::maybe_start_logging(); + let actual = plan_statement_and_get_params( + "select\ + \n date_bin_gapfill(interval '1 minute', time, timestamp '1970-01-01T00:00:00Z') minute\ + \nfrom t\ + \nwhere time >= timestamp '1984-01-01T16:00:00Z' - interval '5 minutes'\ + \n and time <= timestamp '1984-01-01T16:00:00Z'\ + \ngroup by minute", + ).await?; + let expected = GapFillParams { + stride: 60_000_000_000, // 1 minute + first_ts: Some(441_820_500_000_000_000), // Sunday, January 1, 1984 3:55:00 PM + last_ts: 441_820_800_000_000_000, // Sunday, January 1, 1984 3:59:00 PM + fill_strategy: HashMap::new(), + }; + assert_eq!(expected, actual); + Ok(()) + } + + #[tokio::test] + async fn test_evaluate_params_exclude_end() -> Result<()> { + test_helpers::maybe_start_logging(); + let actual = plan_statement_and_get_params( + "select\ + \n date_bin_gapfill(interval '1 minute', time) minute\ + \nfrom t\ + \nwhere time >= timestamp '1984-01-01T16:00:00Z' - interval '5 minutes'\ + \n and time < timestamp '1984-01-01T16:00:00Z'\ + \ngroup by minute", + ) + .await?; + let expected = GapFillParams { + stride: 60_000_000_000, // 1 minute + first_ts: Some(441_820_500_000_000_000), // Sunday, January 1, 1984 3:55:00 PM + // Last bin at 16:00 is excluded +
last_ts: 441_820_740_000_000_000, // Sunday, January 1, 1984 3:59:00 PM + fill_strategy: HashMap::new(), + }; + assert_eq!(expected, actual); + Ok(()) + } + + #[tokio::test] + async fn test_evaluate_params_exclude_start() -> Result<()> { + test_helpers::maybe_start_logging(); + let actual = plan_statement_and_get_params( + "select\ + \n date_bin_gapfill(interval '1 minute', time) minute\ + \nfrom t\ + \nwhere time > timestamp '1984-01-01T16:00:00Z' - interval '5 minutes'\ + \n and time <= timestamp '1984-01-01T16:00:00Z'\ + \ngroup by minute", + ) + .await?; + let expected = GapFillParams { + stride: 60_000_000_000, // 1 minute + // First bin not excluded since it truncates to 15:55:00 + first_ts: Some(441_820_500_000_000_000), // Sunday, January 1, 1984 3:55:00 PM + last_ts: 441_820_800_000_000_000, // Sunday, January 1, 1984 3:59:00 PM + fill_strategy: HashMap::new(), + }; + assert_eq!(expected, actual); + Ok(()) + } + + #[tokio::test] + async fn test_evaluate_params_origin() -> Result<()> { + test_helpers::maybe_start_logging(); + let actual = plan_statement_and_get_params( + // origin is 9s after the epoch + "select\ + \n date_bin_gapfill(interval '1 minute', time, timestamp '1970-01-01T00:00:09Z') minute\ + \nfrom t\ + \nwhere time >= timestamp '1984-01-01T16:00:00Z' - interval '5 minutes'\ + \n and time <= timestamp '1984-01-01T16:00:00Z'\ + \ngroup by minute", + ).await?; + let expected = GapFillParams { + stride: 60_000_000_000, // 1 minute + first_ts: Some(441_820_449_000_000_000), // Sunday, January 1, 1984 3:54:09 PM + last_ts: 441_820_749_000_000_000, // Sunday, January 1, 1984 3:59:09 PM + fill_strategy: HashMap::new(), + }; + assert_eq!(expected, actual); + Ok(()) + } + + fn interval(ns: i64) -> Arc { + Arc::new(Literal::new(ScalarValue::new_interval_mdn(0, 0, ns))) + } + + fn timestamp(ns: i64) -> Arc { + Arc::new(Literal::new(ScalarValue::TimestampNanosecond( + Some(ns), + None, + ))) + } + + #[test] + fn test_params_no_start() { + let exec_params =
GapFillExecParams { + stride: interval(1_000_000_000), + time_column: Column::new("time", 0), + origin: None, + time_range: Range { + start: Bound::Unbounded, + end: Bound::Excluded(timestamp(20_000_000_000)), + }, + fill_strategy: std::iter::once(( + Arc::new(Column::new("a0", 1)) as Arc, + FillStrategy::Null, + )) + .collect(), + }; + + let actual = GapFillParams::try_new(schema().into(), &exec_params).unwrap(); + assert_eq!( + GapFillParams { + stride: 1_000_000_000, + first_ts: None, + last_ts: 19_000_000_000, + fill_strategy: simple_fill_strategy(), + }, + actual + ); + } + + #[test] + #[allow(clippy::reversed_empty_ranges)] + fn test_params_row_count() -> Result<()> { + test_helpers::maybe_start_logging(); + let params = GapFillParams { + stride: 10, + first_ts: Some(1000), + last_ts: 1050, + fill_strategy: simple_fill_strategy(), + }; + + assert_eq!(6, params.valid_row_count(1000)); + assert_eq!(0, params.valid_row_count(1100)); + Ok(()) + } + + fn schema() -> Schema { + Schema::new(vec![ + Field::new( + "time", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new( + "other_time", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new("loc", DataType::Utf8, false), + Field::new("temp", DataType::Float64, false), + ]) + } + + async fn plan_statement_and_get_params(sql: &str) -> Result { + let executor = Executor::new_testing(); + let context = executor.new_context(ExecutorType::Query); + context + .inner() + .register_table("t", Arc::new(EmptyTable::new(Arc::new(schema()))))?; + let physical_plan = context.sql_to_physical_plan(sql).await?; + let gapfill_node = &physical_plan.children()[0]; + let gapfill_node = gapfill_node.as_any().downcast_ref::().unwrap(); + let exec_params = &gapfill_node.params; + let schema = schema(); + GapFillParams::try_new(schema.into(), exec_params) + } + + fn simple_fill_strategy() -> HashMap { + std::iter::once((1, FillStrategy::Null)).collect() + } +} diff --git 
a/iox_query/src/exec/gapfill/stream.rs b/iox_query/src/exec/gapfill/stream.rs new file mode 100644 index 0000000..499de06 --- /dev/null +++ b/iox_query/src/exec/gapfill/stream.rs @@ -0,0 +1,284 @@ +//! Implementation of [Stream] that performs gap-filling on tables. +use std::{ + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use arrow::{ + array::{ArrayRef, TimestampNanosecondArray}, + datatypes::SchemaRef, + record_batch::RecordBatch, +}; + +use arrow_util::optimize::optimize_dictionaries; +use datafusion::{ + error::{DataFusionError, Result}, + execution::memory_pool::MemoryReservation, + physical_plan::{ + expressions::Column, + metrics::{BaselineMetrics, RecordOutput}, + ExecutionPlan, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, + }, +}; +use futures::{ready, Stream, StreamExt}; + +use super::{algo::GapFiller, buffered_input::BufferedInput, params::GapFillParams, GapFillExec}; + +/// An implementation of a gap-filling operator that uses the [Stream] trait. +/// +/// This type takes responsibility for: +/// - Reading input record batches +/// - Accounting for memory +/// - Extracting arrays for processing by [`GapFiller`] +/// - Recording metrics +/// - Sending record batches to next operator (by implementing [`Self::poll_next`]) +#[allow(dead_code)] +pub(super) struct GapFillStream { + /// The schema of the input and output. + schema: SchemaRef, + /// The column from the input that contains the timestamps for each row. + /// This column has already had `date_bin` applied to it by a previous `Aggregate` + /// operator. + time_expr: Arc, + /// The other columns from the input that appeared in the GROUP BY clause of the + /// original query. + group_expr: Vec>, + /// The aggregate columns from the select list of the original query. + aggr_expr: Vec>, + /// The producer of the input record batches. + input: SendableRecordBatchStream, + /// Input that has been read from the input stream.
+ buffered_input: BufferedInput, + /// The thing that does the gap filling. + gap_filler: GapFiller, + /// This is true as long as there are more input record batches to read from `input`. + more_input: bool, + /// For tracking memory. + reservation: MemoryReservation, + /// Baseline metrics. + baseline_metrics: BaselineMetrics, +} + +impl GapFillStream { + /// Creates a new GapFillStream. + pub fn try_new( + exec: &GapFillExec, + batch_size: usize, + input: SendableRecordBatchStream, + reservation: MemoryReservation, + metrics: BaselineMetrics, + ) -> Result { + let schema = exec.schema(); + let GapFillExec { + sort_expr, + aggr_expr, + params, + .. + } = exec; + + if sort_expr.is_empty() { + return Err(DataFusionError::Internal( + "empty sort_expr vector for gap filling; should have at least a time expression" + .to_string(), + )); + } + let mut group_expr = sort_expr + .iter() + .map(|se| Arc::clone(&se.expr)) + .collect::>(); + let aggr_expr = aggr_expr.to_owned(); + let time_expr = group_expr.split_off(group_expr.len() - 1).pop().unwrap(); + + let group_cols = group_expr.iter().map(expr_to_index).collect::>(); + let params = GapFillParams::try_new(Arc::clone(&schema), params)?; + let buffered_input = BufferedInput::new(¶ms, group_cols); + + let gap_filler = GapFiller::new(params, batch_size); + Ok(Self { + schema, + time_expr, + group_expr, + aggr_expr, + input, + buffered_input, + gap_filler, + more_input: true, + reservation, + baseline_metrics: metrics, + }) + } +} + +impl RecordBatchStream for GapFillStream { + fn schema(&self) -> arrow::datatypes::SchemaRef { + Arc::clone(&self.schema) + } +} + +impl Stream for GapFillStream { + type Item = Result; + + /// Produces a gap-filled record batch from its input stream. + /// + /// For details on implementation, see [`GapFiller`]. 
+ fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let last_output_row_offset = self.gap_filler.last_output_row_offset(); + while self.more_input && self.buffered_input.need_more(last_output_row_offset)? { + match ready!(self.input.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + self.reservation.try_grow(batch.get_array_memory_size())?; + self.buffered_input.push(batch); + } + Some(Err(e)) => { + return Poll::Ready(Some(Err(e))); + } + None => { + self.more_input = false; + } + } + } + + let input_batch = match self.take_buffered_input() { + Ok(None) => return Poll::Ready(None), + Ok(Some(input_batch)) => { + // If we have consumed all of our input, and there is no more work + if self.gap_filler.done(input_batch.num_rows()) { + // leave the input batch taken so that its reference + // count goes to zero. + self.reservation.shrink(input_batch.get_array_memory_size()); + return Poll::Ready(None); + } + + input_batch + } + Err(e) => return Poll::Ready(Some(Err(e))), + }; + + match self.process(input_batch) { + Ok((output_batch, remaining_input_batch)) => { + self.buffered_input.push(remaining_input_batch); + + self.reservation + .shrink(output_batch.get_array_memory_size()); + Poll::Ready(Some(Ok(output_batch))) + } + Err(e) => Poll::Ready(Some(Err(e))), + } + } +} + +impl GapFillStream { + /// If any buffered input batches are present, concatenates it all together + /// and returns an owned batch to the caller, leaving `self.buffered_input_batches` empty. + fn take_buffered_input(&mut self) -> Result> { + let batches = self.buffered_input.take(); + if batches.is_empty() { + return Ok(None); + } + + let old_size = batches.iter().map(|rb| rb.get_array_memory_size()).sum(); + + let mut batch = arrow::compute::concat_batches(&self.schema, &batches) + .map_err(|err| DataFusionError::ArrowError(err, None))?; + self.reservation.try_grow(batch.get_array_memory_size())?; + + if batches.len() > 1 { + // Optimize the dictionaries. 
The output of this operator uses the take kernel to produce + // its output. Since the input batches will usually be smaller than the output, it should + // be less work to optimize here vs optimizing the output. + batch = optimize_dictionaries(&batch) + .map_err(|err| DataFusionError::ArrowError(err, None))?; + } + + self.reservation.shrink(old_size); + Ok(Some(batch)) + } + + /// Produces a 2-tuple of [RecordBatch]es: + /// - The gap-filled output + /// - Remaining buffered input + fn process(&mut self, mut input_batch: RecordBatch) -> Result<(RecordBatch, RecordBatch)> { + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + + let input_time_array = self + .time_expr + .evaluate(&input_batch)? + .into_array(input_batch.num_rows())?; + let input_time_array: &TimestampNanosecondArray = input_time_array + .as_any() + .downcast_ref() + .ok_or(DataFusionError::Internal( + "time array must be a TimestampNanosecondArray".to_string(), + ))?; + let input_time_array = (expr_to_index(&self.time_expr), input_time_array); + + let group_arrays = self.group_arrays(&input_batch)?; + let aggr_arrays = self.aggr_arrays(&input_batch)?; + + let timer = elapsed_compute.timer(); + let output_batch = self + .gap_filler + .build_gapfilled_output( + Arc::clone(&self.schema), + input_time_array, + &group_arrays, + &aggr_arrays, + ) + .record_output(&self.baseline_metrics)?; + timer.done(); + + self.reservation + .try_grow(output_batch.get_array_memory_size())?; + + // Slice the input to just what is needed moving forward, with one context + // row before the next input offset. + input_batch = self.gap_filler.slice_input_batch(input_batch)?; + + Ok((output_batch, input_batch)) + } + + /// Produces the arrays for the group columns in the input. + /// The first item in the 2-tuple is the arrays offset in the schema. 
+ fn group_arrays(&self, input_batch: &RecordBatch) -> Result> { + self.group_expr + .iter() + .map(|e| { + Ok(( + expr_to_index(e), + e.evaluate(input_batch)? + .into_array(input_batch.num_rows())?, + )) + }) + .collect::>>() + } + + /// Produces the arrays for the aggregate columns in the input. + /// The first item in the 2-tuple is the arrays offset in the schema. + fn aggr_arrays(&self, input_batch: &RecordBatch) -> Result> { + self.aggr_expr + .iter() + .map(|e| { + Ok(( + expr_to_index(e), + e.evaluate(input_batch)? + .into_array(input_batch.num_rows())?, + )) + }) + .collect::>>() + } +} + +/// Returns the index of the given expression in the schema, +/// assuming that it is a column. +/// +/// # Panic +/// Panics if the expression is not a column. +fn expr_to_index(expr: &Arc) -> usize { + expr.as_any() + .downcast_ref::() + .expect("all exprs should be columns") + .index() +} diff --git a/iox_query/src/exec/metrics.rs b/iox_query/src/exec/metrics.rs new file mode 100644 index 0000000..7a39768 --- /dev/null +++ b/iox_query/src/exec/metrics.rs @@ -0,0 +1,52 @@ +use std::{ + borrow::Cow, + sync::{Arc, Weak}, +}; + +use datafusion::execution::memory_pool::MemoryPool; +use metric::{Attributes, Instrument, MetricKind, Observation, Reporter}; + +/// Hooks DataFusion [`MemoryPool`] into our [`metric`] crate. +#[derive(Debug, Clone)] +pub struct DataFusionMemoryPoolMetricsBridge { + pool: Weak, + limit: usize, +} + +impl DataFusionMemoryPoolMetricsBridge { + /// Register new pool. 
+ pub fn new(pool: &Arc, limit: usize) -> Self { + Self { + pool: Arc::downgrade(pool), + limit, + } + } +} + +impl Instrument for DataFusionMemoryPoolMetricsBridge { + fn report(&self, reporter: &mut dyn Reporter) { + reporter.start_metric( + "datafusion_mem_pool_bytes", + "Number of bytes within the datafusion memory pool", + MetricKind::U64Gauge, + ); + let Some(pool_arc) = self.pool.upgrade() else { + return; + }; + + reporter.report_observation( + &Attributes::from([("state", Cow::Borrowed("limit"))]), + Observation::U64Gauge(self.limit as u64), + ); + + reporter.report_observation( + &Attributes::from([("state", Cow::Borrowed("reserved"))]), + Observation::U64Gauge(pool_arc.reserved() as u64), + ); + reporter.finish_metric(); + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} diff --git a/iox_query/src/exec/non_null_checker.rs b/iox_query/src/exec/non_null_checker.rs new file mode 100644 index 0000000..8a60bd7 --- /dev/null +++ b/iox_query/src/exec/non_null_checker.rs @@ -0,0 +1,478 @@ +//! This module contains code for the "NonNullChecker" DataFusion +//! extension plan node +//! +//! A NonNullChecker node takes an arbitrary input array and produces +//! a single string output column that contains +//! +//! 1. A single string if any of the input columns are non-null +//! 2. zero rows if all of the input columns are null +//! +//! For this input: +//! +//! ColA | ColB | ColC +//! ------+------+------ +//! 1 | NULL | NULL +//! 2 | 2 | NULL +//! 3 | 2 | NULL +//! +//! The output would be (given 'the_value' was provided to `NonNullChecker` node) +//! +//! non_null_column +//! ----------------- +//! the_value +//! +//! However, for this input (All NULL) +//! +//! ColA | ColB | ColC +//! ------+------+------ +//! NULL | NULL | NULL +//! NULL | NULL | NULL +//! NULL | NULL | NULL +//! +//! There would be no output rows +//! +//! non_null_column +//! ----------------- +//! +//! 
This operation can be used to implement the table_name metadata query + +use std::{ + fmt::{self, Debug}, + sync::Arc, +}; + +use arrow::{ + array::{new_empty_array, StringArray}, + datatypes::{DataType, Field, Schema, SchemaRef}, + record_batch::RecordBatch, +}; +use datafusion::logical_expr::expr_vec_fmt; +use datafusion::{ + common::{DFSchemaRef, ToDFSchema}, + error::{DataFusionError, Result}, + execution::context::TaskContext, + logical_expr::{Expr, LogicalPlan, UserDefinedLogicalNodeCore}, + physical_plan::{ + expressions::PhysicalSortExpr, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + SendableRecordBatchStream, Statistics, + }, +}; + +use datafusion_util::{watch::WatchedTask, AdapterStream}; +use observability_deps::tracing::debug; +use tokio::sync::mpsc; +use tokio_stream::StreamExt; + +/// Implements the NonNullChecker operation as described in this module's documentation +#[derive(Hash, PartialEq, Eq)] +pub struct NonNullCheckerNode { + input: LogicalPlan, + schema: DFSchemaRef, + /// these expressions represent what columns are "used" by this + /// node (in this case all of them) -- columns that are not used + /// are optimized away by datafusion.
+ exprs: Vec, + + /// The value to produce if there are any non null Inputs + value: Arc, +} + +impl NonNullCheckerNode { + /// Creates a new NonNullChecker node + /// + /// # Panics + /// If the input schema is empty + pub fn new(value: &str, input: LogicalPlan) -> Self { + let schema = make_non_null_checker_output_schema(); + + // Form exprs that refer to all of our input columns (so that + // datafusion knows not to opimize them away) + let exprs = input + .schema() + .fields() + .iter() + .map(|field| Expr::Column(field.qualified_column())) + .collect::>(); + + assert!(!exprs.is_empty(), "NonNullChecker: input schema was empty"); + + Self { + input, + schema, + exprs, + value: value.into(), + } + } + + /// Return the value associated with this checker + pub fn value(&self) -> Arc { + Arc::clone(&self.value) + } +} + +impl Debug for NonNullCheckerNode { + /// Use explain format for the Debug format. + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.fmt_for_explain(f) + } +} + +impl UserDefinedLogicalNodeCore for NonNullCheckerNode { + fn name(&self) -> &str { + "NonNullChecker" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + /// Schema for Pivot is a single string + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + self.exprs.clone() + } + + /// For example: `NonNullChecker('the_value'), exprs=[foo]` + fn fmt_for_explain(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}('{}') exprs={}", + self.name(), + self.value, + expr_vec_fmt!(self.exprs) + ) + } + + fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + assert_eq!(inputs.len(), 1, "NonNullChecker: input sizes inconsistent"); + assert_eq!( + exprs.len(), + self.exprs.len(), + "NonNullChecker: expression sizes inconsistent" + ); + Self::new(self.value.as_ref(), inputs[0].clone()) + } +} + +// ------ The implementation of NonNullChecker code follows ----- + +/// Create the schema 
describing the output +pub fn make_non_null_checker_output_schema() -> DFSchemaRef { + let nullable = false; + Schema::new(vec![Field::new( + "non_null_column", + DataType::Utf8, + nullable, + )]) + .to_dfschema_ref() + .unwrap() +} + +/// Physical operator that implements the NonNullChecker operation aginst +/// data types +pub struct NonNullCheckerExec { + input: Arc, + /// Output schema + schema: SchemaRef, + /// The value to produce if there are any non null Inputs + value: Arc, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, +} + +impl NonNullCheckerExec { + pub fn new(input: Arc, schema: SchemaRef, value: Arc) -> Self { + Self { + input, + schema, + value, + metrics: ExecutionPlanMetricsSet::new(), + } + } +} + +impl Debug for NonNullCheckerExec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "NonNullCheckerExec") + } +} + +impl ExecutionPlan for NonNullCheckerExec { + fn as_any(&self) -> &(dyn std::any::Any + 'static) { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn output_partitioning(&self) -> Partitioning { + use Partitioning::*; + match self.input.output_partitioning() { + RoundRobinBatch(num_partitions) => RoundRobinBatch(num_partitions), + // as this node transforms the output schema, whatever partitioning + // was present on the input is lost on the output + Hash(_, num_partitions) => UnknownPartitioning(num_partitions), + UnknownPartitioning(num_partitions) => UnknownPartitioning(num_partitions), + } + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn required_input_distribution(&self) -> Vec { + vec![Distribution::UnspecifiedDistribution] + } + + fn children(&self) -> Vec> { + vec![Arc::clone(&self.input)] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + match children.len() { + 1 => Ok(Arc::new(Self { + input: Arc::clone(&children[0]), + schema: Arc::clone(&self.schema), + metrics: 
ExecutionPlanMetricsSet::new(), + value: Arc::clone(&self.value), + })), + _ => Err(DataFusionError::Internal( + "NonNullCheckerExec wrong number of children".to_string(), + )), + } + } + + /// Execute one partition and return an iterator over RecordBatch + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + debug!(partition, "Start NonNullCheckerExec::execute"); + if self.output_partitioning().partition_count() <= partition { + return Err(DataFusionError::Internal(format!( + "NonNullCheckerExec invalid partition {partition}" + ))); + } + + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + let input_stream = self.input.execute(partition, context)?; + + let (tx, rx) = mpsc::channel(1); + + let fut = check_for_nulls( + input_stream, + Arc::clone(&self.schema), + baseline_metrics, + Arc::clone(&self.value), + tx.clone(), + ); + + // A second task watches the output of the worker task and + // reports errors + let handle = WatchedTask::new(fut, vec![tx], "non_null_checker"); + + debug!(partition, "End NonNullCheckerExec::execute"); + Ok(AdapterStream::adapt(self.schema(), rx, handle)) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + Ok(Statistics::new_unknown(&self.schema())) + } +} + +impl DisplayAs for NonNullCheckerExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "NonNullCheckerExec") + } + } + } +} + +async fn check_for_nulls( + mut input_stream: SendableRecordBatchStream, + schema: SchemaRef, + baseline_metrics: BaselineMetrics, + value: Arc, + tx: mpsc::Sender>, +) -> Result<(), DataFusionError> { + while let Some(input_batch) = input_stream.next().await.transpose()? 
{ + let timer = baseline_metrics.elapsed_compute().timer(); + + if input_batch + .columns() + .iter() + .any(|arr| arr.null_count() != arr.len()) + { + // found a non null in input, return value + let arr: StringArray = vec![Some(value.as_ref())].into(); + + let output_batch = RecordBatch::try_new(schema, vec![Arc::new(arr)])?; + // ignore errors on sending (means receiver hung up) + std::mem::drop(timer); + tx.send(Ok(output_batch)).await.ok(); + return Ok(()); + } + // else keep looking + } + // if we got here, did not see any non null values. So + // send back an empty record batch + let output_batch = RecordBatch::try_new(schema, vec![new_empty_array(&DataType::Utf8)])?; + + // ignore errors on sending (means receiver hung up) + tx.send(Ok(output_batch)).await.ok(); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ArrayRef, StringArray}; + use arrow_util::assert_batches_eq; + use datafusion::physical_plan::memory::MemoryExec; + use datafusion_util::test_collect; + + #[tokio::test] + async fn test_single_column_non_null() { + let t1 = StringArray::from(vec![Some("a"), Some("c"), Some("c")]); + let batch = RecordBatch::try_from_iter(vec![("t1", Arc::new(t1) as ArrayRef)]).unwrap(); + + let results = check("the_value", vec![batch]).await; + + let expected = vec![ + "+-----------------+", + "| non_null_column |", + "+-----------------+", + "| the_value |", + "+-----------------+", + ]; + assert_batches_eq!(&expected, &results); + } + + #[tokio::test] + async fn test_single_column_null() { + let t1 = StringArray::from(vec![None::<&str>, None, None]); + let batch = RecordBatch::try_from_iter(vec![("t1", Arc::new(t1) as ArrayRef)]).unwrap(); + + let results = check("the_value", vec![batch]).await; + + let expected = vec![ + "+-----------------+", + "| non_null_column |", + "+-----------------+", + "+-----------------+", + ]; + assert_batches_eq!(&expected, &results); + } + + #[tokio::test] + async fn test_multi_column_non_null() { + let t1 
= StringArray::from(vec![None::<&str>, None, None]); + let t2 = StringArray::from(vec![None::<&str>, None, Some("c")]); + let batch = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("t2", Arc::new(t2) as ArrayRef), + ]) + .unwrap(); + + let results = check("the_value", vec![batch]).await; + + let expected = vec![ + "+-----------------+", + "| non_null_column |", + "+-----------------+", + "| the_value |", + "+-----------------+", + ]; + assert_batches_eq!(&expected, &results); + } + + #[tokio::test] + async fn test_multi_column_null() { + let t1 = StringArray::from(vec![None::<&str>, None, None]); + let t2 = StringArray::from(vec![None::<&str>, None, None]); + let batch = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("t2", Arc::new(t2) as ArrayRef), + ]) + .unwrap(); + + let results = check("the_value", vec![batch]).await; + + let expected = vec![ + "+-----------------+", + "| non_null_column |", + "+-----------------+", + "+-----------------+", + ]; + assert_batches_eq!(&expected, &results); + } + + #[tokio::test] + async fn test_multi_column_second_batch_non_null() { + // this time only the second batch has a non null value + let t1 = StringArray::from(vec![None::<&str>, None, None]); + let t2 = StringArray::from(vec![None::<&str>, None, None]); + + let batch1 = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("t2", Arc::new(t2) as ArrayRef), + ]) + .unwrap(); + + let t1 = StringArray::from(vec![None::<&str>]); + let t2 = StringArray::from(vec![Some("f")]); + + let batch2 = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("t2", Arc::new(t2) as ArrayRef), + ]) + .unwrap(); + + let results = check("another_value", vec![batch1, batch2]).await; + + let expected = vec![ + "+-----------------+", + "| non_null_column |", + "+-----------------+", + "| another_value |", + "+-----------------+", + ]; + assert_batches_eq!(&expected, &results); + } + + /// Run the input through 
the checker and return results + async fn check(value: &str, input: Vec) -> Vec { + test_helpers::maybe_start_logging(); + + // Setup in memory stream + let schema = input[0].schema(); + let projection = None; + let input = Arc::new(MemoryExec::try_new(&[input], schema, projection).unwrap()); + + // Create and run the checker + let schema: Schema = make_non_null_checker_output_schema().as_ref().into(); + let exec = Arc::new(NonNullCheckerExec::new( + input, + Arc::new(schema), + value.into(), + )); + + test_collect(exec as Arc).await + } +} diff --git a/iox_query/src/exec/query_tracing.rs b/iox_query/src/exec/query_tracing.rs new file mode 100644 index 0000000..de639c3 --- /dev/null +++ b/iox_query/src/exec/query_tracing.rs @@ -0,0 +1,703 @@ +//! This module contains the code to map DataFusion metrics to `Span`s +//! for use in distributed tracing (e.g. Jaeger) + +use arrow::record_batch::RecordBatch; +use chrono::{DateTime, Utc}; +use datafusion::error::DataFusionError; +use datafusion::physical_plan::{ + metrics::{MetricValue, MetricsSet}, + DisplayFormatType, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, +}; +use futures::StreamExt; +use hashbrown::HashMap; +use observability_deps::tracing::debug; +use std::{fmt, sync::Arc}; +use trace::span::{Span, SpanRecorder}; + +const PER_PARTITION_TRACING_ENABLE_ENV: &str = "INFLUXDB_IOX_PER_PARTITION_TRACING"; +fn per_partition_tracing() -> bool { + use std::sync::atomic::{AtomicU8, Ordering}; + static TRACING_ENABLED: AtomicU8 = AtomicU8::new(u8::MAX); + + match TRACING_ENABLED.load(Ordering::Relaxed) { + u8::MAX => { + let val = std::env::var(PER_PARTITION_TRACING_ENABLE_ENV) + .ok() + .and_then(|x| x.parse::().ok()) + .map(Into::into) + .unwrap_or(false); + + TRACING_ENABLED.store(val as u8, Ordering::Relaxed); + val + } + x => x != 0, + } +} + +/// Stream wrapper that records DataFusion `MetricSets` into IOx +/// [`Span`]s when it is dropped. 
+pub(crate) struct TracedStream { + inner: SendableRecordBatchStream, + span_recorder: SpanRecorder, + physical_plan: Arc, +} + +impl TracedStream { + /// Return a stream that records DataFusion `MetricSets` from + /// `physical_plan` into `span` when dropped. + pub(crate) fn new( + inner: SendableRecordBatchStream, + span: Option, + physical_plan: Arc, + ) -> Self { + Self { + inner, + span_recorder: SpanRecorder::new(span), + physical_plan, + } + } +} + +impl RecordBatchStream for TracedStream { + fn schema(&self) -> arrow::datatypes::SchemaRef { + self.inner.schema() + } +} + +impl futures::Stream for TracedStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_next_unpin(cx) + } +} + +impl Drop for TracedStream { + fn drop(&mut self) { + if let Some(span) = self.span_recorder.span() { + let default_end_time = Utc::now(); + let per_partition_tracing = per_partition_tracing(); + send_metrics_to_tracing( + default_end_time, + span, + self.physical_plan.as_ref(), + per_partition_tracing, + ); + } + } +} + +/// This function translates data in DataFusion `MetricSets` into IOx +/// [`Span`]s. It records a snapshot of the current state of the +/// DataFusion metrics, so it should only be invoked *after* a plan is +/// fully `collect`ed. +/// +/// Each `ExecutionPlan` in the plan gets its own new [`Span`] that covers +/// the time spent executing its partitions and its children +/// +/// Each `ExecutionPlan` also has a new [`Span`] for each of its +/// partitions that collected metrics +/// +/// The start and end time of the span are taken from the +/// ExecutionPlan's metrics, falling back to the parent span's +/// timestamps if there are no metrics +/// +/// Span metadata is used to record: +/// 1. If the ExecutionPlan had no metrics +/// 2. The total number of rows produced by the ExecutionPlan (if available) +/// 3. 
The elapsed compute time taken by the ExecutionPlan +pub fn send_metrics_to_tracing( + default_end_time: DateTime, + parent_span: &Span, + physical_plan: &dyn ExecutionPlan, + per_partition_tracing: bool, +) { + // Something like this when one_line is contributed back upstream + //let plan_name = physical_plan.displayable().one_line().to_string(); + let desc = one_line(physical_plan).to_string(); + let operator_name: String = desc.chars().take_while(|x| *x != ':').collect(); + + // Get the timings of the parent operator + let parent_start_time = parent_span.start.unwrap_or(default_end_time); + let parent_end_time = parent_span.end.unwrap_or(default_end_time); + + // A span for the operation, this is the aggregate of all the partition spans + let mut operator_span = parent_span.child(operator_name.clone()); + operator_span.metadata.insert("desc".into(), desc.into()); + + let mut operator_metrics = SpanMetrics { + output_rows: None, + elapsed_compute_nanos: None, + }; + + // The total duration for this span and all its children and partitions + let mut operator_start_time = DateTime::::MAX_UTC; + let mut operator_end_time = DateTime::::MIN_UTC; + + match physical_plan.metrics() { + None => { + // this DataFusion node had no metrics, so record that in + // metadata and use the start/stop time of the parent span + operator_span + .metadata + .insert("missing_statistics".into(), "true".into()); + } + Some(metrics) => { + // Create a separate span for each partition in the operator + for (partition, metrics) in partition_metrics(metrics) { + let (start_ts, end_ts) = get_timestamps(&metrics); + + let partition_start_time = start_ts.unwrap_or(parent_start_time); + let partition_end_time = end_ts.unwrap_or(parent_end_time); + + let partition_metrics = SpanMetrics { + output_rows: metrics.output_rows(), + elapsed_compute_nanos: metrics.elapsed_compute(), + }; + + operator_start_time = operator_start_time.min(partition_start_time); + operator_end_time = 
operator_end_time.max(partition_end_time); + + // Update the aggregate totals in the operator span + operator_metrics.aggregate_child(&partition_metrics); + + // Generate a span for the partition if + // - these metrics correspond to a partition + // - per partition tracing is enabled + if per_partition_tracing { + if let Some(partition) = partition { + let mut partition_span = + operator_span.child(format!("{operator_name} ({partition})")); + + partition_span.start = Some(partition_start_time); + partition_span.end = Some(partition_end_time); + + partition_metrics.add_to_span(&mut partition_span); + + partition_span.export(); + } + } + } + } + } + + // If we've not encountered any metrics to determine the operator's start + // and end time, use those of the parent + if operator_start_time == DateTime::::MAX_UTC { + operator_start_time = parent_span.start.unwrap_or(default_end_time); + } + + if operator_end_time == DateTime::::MIN_UTC { + operator_end_time = parent_span.end.unwrap_or(default_end_time); + } + + operator_span.start = Some(operator_start_time); + operator_span.end = Some(operator_end_time); + + // recurse + for child in physical_plan.children() { + send_metrics_to_tracing( + operator_end_time, + &operator_span, + child.as_ref(), + per_partition_tracing, + ); + } + + operator_metrics.add_to_span(&mut operator_span); + operator_span.export(); +} + +#[derive(Debug)] +struct SpanMetrics { + output_rows: Option, + elapsed_compute_nanos: Option, +} + +impl SpanMetrics { + fn aggregate_child(&mut self, child: &Self) { + if let Some(rows) = child.output_rows { + *self.output_rows.get_or_insert(0) += rows; + } + + if let Some(nanos) = child.elapsed_compute_nanos { + *self.elapsed_compute_nanos.get_or_insert(0) += nanos; + } + } + + fn add_to_span(&self, span: &mut Span) { + if let Some(rows) = self.output_rows { + span.metadata + .insert("output_rows".into(), (rows as i64).into()); + } + + if let Some(nanos) = self.elapsed_compute_nanos { + span.metadata + 
.insert("elapsed_compute_nanos".into(), (nanos as i64).into()); + } + } +} + +fn partition_metrics(metrics: MetricsSet) -> HashMap, MetricsSet> { + let mut hashmap = HashMap::<_, MetricsSet>::new(); + for metric in metrics.iter() { + hashmap + .entry(metric.partition()) + .or_default() + .push(Arc::clone(metric)) + } + hashmap +} + +// todo contribute this back upstream to datafusion (add to `DisplayableExecutionPlan`) + +/// Return a `Display`able structure that produces a single line, for +/// this node only (does not recurse to children) +pub fn one_line(plan: &dyn ExecutionPlan) -> impl fmt::Display + '_ { + struct Wrapper<'a> { + plan: &'a dyn ExecutionPlan, + } + impl<'a> fmt::Display for Wrapper<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let t = DisplayFormatType::Default; + self.plan.fmt_as(t, f) + } + } + + Wrapper { plan } +} + +// TODO maybe also contribute these back upstream to datafusion (make +// as a method on MetricsSet) + +/// Return the start, and end timestamps of the metrics set, if any +fn get_timestamps(metrics: &MetricsSet) -> (Option>, Option>) { + let mut start_ts = None; + let mut end_ts = None; + + for metric in metrics.iter() { + if metric.labels().is_empty() { + match metric.value() { + MetricValue::StartTimestamp(ts) => { + if ts.value().is_some() && start_ts.is_some() { + debug!( + ?metric, + ?start_ts, + "WARNING: more than one StartTimestamp metric found" + ) + } + start_ts = ts.value() + } + MetricValue::EndTimestamp(ts) => { + if ts.value().is_some() && end_ts.is_some() { + debug!( + ?metric, + ?end_ts, + "WARNING: more than one EndTimestamp metric found" + ) + } + end_ts = ts.value() + } + _ => {} + } + } + } + + (start_ts, end_ts) +} + +/// Boolean flag that works with environment variables. 
/// Boolean flag that works with environment variables.
///
/// Accepts the usual spellings of yes/no (`yes`, `y`, `true`, `t`, `1`
/// and their negative counterparts), case-insensitively, via [`FromStr`],
/// and converts losslessly into `bool`.
#[derive(Debug, Clone, Copy)]
pub enum BooleanFlag {
    True,
    False,
}

impl std::str::FromStr for BooleanFlag {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_ascii_lowercase().as_str() {
            "yes" | "y" | "true" | "t" | "1" => Ok(Self::True),
            "no" | "n" | "false" | "f" | "0" => Ok(Self::False),
            _ => Err(format!(
                "Invalid boolean flag '{s}'. Valid options: yes, no, y, n, true, false, t, f, 1, 0"
            )),
        }
    }
}

impl From<BooleanFlag> for bool {
    fn from(yes_no: BooleanFlag) -> Self {
        matches!(yes_no, BooleanFlag::True)
    }
}
-------- ts4] <-- both start and end timestamps + // child1: [ ts2 - ] <-- only start timestamp + // child2: [ ts2 --- ts3] <-- both start and end timestamps + // child3: [ --- ts3] <-- only end timestamps (e.g. bad data) + // child4: [ ] <-- no timestamps + // child5 (1): [ --- ts2] + // child5 (2): [ ts2 --- ts3] + // child5 (4): [ ts1 --- ] + let mut exec = TestExec::new("exec", make_time_metric_set(Some(ts1), Some(ts4), Some(1))); + exec.new_child( + "child1: foo", + make_time_metric_set(Some(ts2), None, Some(1)), + ); + exec.new_child( + "child2: bar", + make_time_metric_set(Some(ts2), Some(ts3), None), + ); + exec.new_child( + "child3: baz", + make_time_metric_set(None, Some(ts3), Some(1)), + ); + exec.new_child("child4: bingo", make_time_metric_set(None, None, Some(1))); + exec.new_child("child5: bongo", many_partition); + + let traces = TraceBuilder::new(); + send_metrics_to_tracing(ts5, &traces.make_span(), &exec, true); + + let spans = traces.spans(); + let spans: BTreeMap<_, _> = spans.iter().map(|s| (s.name.as_ref(), s)).collect(); + + println!("Spans: \n\n{spans:#?}"); + assert_eq!(spans.len(), 10); + + let check_span = |span: &Span, expected_start, expected_end, desc: Option<&str>| { + assert_eq!(span.start, expected_start, "expected start; {span:?}"); + assert_eq!(span.end, expected_end, "expected end; {span:?}"); + assert_eq!(span.metadata.get("desc").map(|x| x.string().unwrap()), desc); + }; + + check_span( + spans["TestExec - exec"], + Some(ts1), + Some(ts4), + Some("TestExec - exec"), + ); + + check_span( + spans["TestExec - child1"], + Some(ts2), + Some(ts4), + Some("TestExec - child1: foo"), + ); + + check_span( + spans["TestExec - child2"], + Some(ts2), + Some(ts3), + Some("TestExec - child2: bar"), + ); + + check_span( + spans["TestExec - child3"], + Some(ts1), + Some(ts3), + Some("TestExec - child3: baz"), + ); + check_span(spans["TestExec - child3 (1)"], Some(ts1), Some(ts3), None); + + check_span( + spans["TestExec - child4"], + Some(ts1), 
+ Some(ts4), + Some("TestExec - child4: bingo"), + ); + + check_span( + spans["TestExec - child5"], + Some(ts1), + Some(ts4), + Some("TestExec - child5: bongo"), + ); + check_span(spans["TestExec - child5 (1)"], Some(ts1), Some(ts2), None); + check_span(spans["TestExec - child5 (2)"], Some(ts2), Some(ts3), None); + check_span(spans["TestExec - child5 (3)"], Some(ts1), Some(ts4), None); + } + + #[test] + fn no_metrics() { + // given execution plan with no metrics, should add notation on metadata + let mut exec = TestExec::new("exec", Default::default()); + exec.metrics = None; + + let traces = TraceBuilder::new(); + send_metrics_to_tracing(Utc::now(), &traces.make_span(), &exec, true); + + let spans = traces.spans(); + assert_eq!(spans.len(), 1); + assert_eq!( + spans[0].metadata.get("missing_statistics"), + Some(&MetaValue::String("true".into())), + "spans: {spans:#?}" + ); + } + + // row count and elapsed compute + #[test] + fn metrics() { + // given execution plan with execution time and compute spread across two partitions (1, and 2) + let mut exec = TestExec::new("exec", Default::default()); + add_output_rows(exec.metrics_mut(), 100, 1); + add_output_rows(exec.metrics_mut(), 200, 2); + + add_elapsed_compute(exec.metrics_mut(), 1000, 1); + add_elapsed_compute(exec.metrics_mut(), 2000, 2); + + let traces = TraceBuilder::new(); + send_metrics_to_tracing(Utc::now(), &traces.make_span(), &exec, true); + + // aggregated metrics should be reported + let spans = traces.spans(); + let spans: BTreeMap<_, _> = spans.iter().map(|s| (s.name.as_ref(), s)).collect(); + + assert_eq!(spans.len(), 3); + + let check_span = |span: &Span, output_row: i64, nanos: i64| { + assert_eq!( + span.metadata.get("output_rows"), + Some(&MetaValue::Int(output_row)), + "span: {span:#?}" + ); + + assert_eq!( + span.metadata.get("elapsed_compute_nanos"), + Some(&MetaValue::Int(nanos)), + "spans: {span:#?}" + ); + }; + + check_span(spans["TestExec - exec"], 300, 3000); + check_span(spans["TestExec 
- exec (1)"], 100, 1000); + check_span(spans["TestExec - exec (2)"], 200, 2000); + } + + fn add_output_rows(metrics: &mut MetricsSet, output_rows: usize, partition: usize) { + let value = Count::new(); + value.add(output_rows); + + let partition = Some(partition); + metrics.push(Arc::new(Metric::new( + MetricValue::OutputRows(value), + partition, + ))); + } + + fn add_elapsed_compute(metrics: &mut MetricsSet, elapsed_compute: u64, partition: usize) { + let value = Time::new(); + value.add_duration(Duration::from_nanos(elapsed_compute)); + + let partition = Some(partition); + metrics.push(Arc::new(Metric::new( + MetricValue::ElapsedCompute(value), + partition, + ))); + } + + fn make_time_metric_set( + start: Option>, + end: Option>, + partition: Option, + ) -> MetricsSet { + let mut metrics = MetricsSet::new(); + add_time_metrics(&mut metrics, start, end, partition); + metrics + } + + fn add_time_metrics( + metrics: &mut MetricsSet, + start: Option>, + end: Option>, + partition: Option, + ) { + if let Some(start) = start { + let value = make_metrics_timestamp(start); + metrics.push(Arc::new(Metric::new( + MetricValue::StartTimestamp(value), + partition, + ))); + } + + if let Some(end) = end { + let value = make_metrics_timestamp(end); + metrics.push(Arc::new(Metric::new( + MetricValue::EndTimestamp(value), + partition, + ))); + } + } + + fn make_metrics_timestamp(t: DateTime) -> Timestamp { + let timestamp = Timestamp::new(); + timestamp.set(t); + timestamp + } + + /// Encapsulates creating and capturing spans for tests + struct TraceBuilder { + collector: Arc, + } + + impl TraceBuilder { + fn new() -> Self { + Self { + collector: Arc::new(RingBufferTraceCollector::new(10)), + } + } + + // create a new span connected to the collector + fn make_span(&self) -> Span { + SpanContext::new(Arc::clone(&self.collector) as _).child("foo") + } + + /// return all collected spans + fn spans(&self) -> Vec { + self.collector.spans() + } + } + + /// mocked out execution plan we 
can control metrics + #[derive(Debug)] + struct TestExec { + name: String, + metrics: Option, + children: Vec>, + } + + impl TestExec { + fn new(name: impl Into, metrics: MetricsSet) -> Self { + Self { + name: name.into(), + metrics: Some(metrics), + children: vec![], + } + } + + fn new_child(&mut self, name: impl Into, metrics: MetricsSet) { + self.children.push(Arc::new(Self::new(name, metrics))); + } + + fn metrics_mut(&mut self) -> &mut MetricsSet { + self.metrics.as_mut().unwrap() + } + } + + impl ExecutionPlan for TestExec { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> arrow::datatypes::SchemaRef { + unimplemented!() + } + + fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning { + unimplemented!() + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + unimplemented!() + } + + fn children(&self) -> Vec> { + self.children.clone() + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> datafusion::error::Result> { + unimplemented!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> datafusion::error::Result + { + unimplemented!() + } + + fn statistics(&self) -> Result { + Ok(datafusion::physical_plan::Statistics::new_unknown( + &self.schema(), + )) + } + + fn metrics(&self) -> Option { + self.metrics.clone() + } + } + + impl DisplayAs for TestExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "TestExec - {}", self.name) + } + } + + #[test] + fn test_parsing() { + assert!(bool::from(BooleanFlag::from_str("yes").unwrap())); + assert!(bool::from(BooleanFlag::from_str("Yes").unwrap())); + assert!(bool::from(BooleanFlag::from_str("YES").unwrap())); + + assert!(!bool::from(BooleanFlag::from_str("No").unwrap())); + assert!(!bool::from(BooleanFlag::from_str("FaLse").unwrap())); + + BooleanFlag::from_str("foo").unwrap_err(); + } +} diff --git a/iox_query/src/exec/schema_pivot.rs b/iox_query/src/exec/schema_pivot.rs 
new file mode 100644 index 0000000..a3e3d3a --- /dev/null +++ b/iox_query/src/exec/schema_pivot.rs @@ -0,0 +1,561 @@ +//! This module contains code for the "SchemaPivot" DataFusion +//! extension plan node +//! +//! A SchemaPivot node takes an arbitrary input like +//! +//! ColA | ColB | ColC +//! ------+------+------ +//! 1 | NULL | NULL +//! 2 | 2 | NULL +//! 3 | 2 | NULL +//! +//! And pivots it to a table with a single string column for any +//! columns that had non null values. +//! +//! non_null_column +//! ----------------- +//! "ColA" +//! "ColB" +//! +//! This operation can be used to implement the tag_keys metadata query + +use std::{ + fmt::{self, Debug}, + sync::Arc, +}; + +use arrow::{ + array::StringArray, + datatypes::{DataType, Field, Schema, SchemaRef}, + error::ArrowError, + record_batch::RecordBatch, +}; +use datafusion::{ + common::{DFSchemaRef, ToDFSchema}, + error::{DataFusionError as Error, Result}, + execution::context::TaskContext, + logical_expr::{Expr, LogicalPlan, UserDefinedLogicalNodeCore}, + physical_plan::{ + expressions::PhysicalSortExpr, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput}, + DisplayFormatType, Distribution, ExecutionPlan, Partitioning, SendableRecordBatchStream, + Statistics, + }, +}; +use datafusion::{error::DataFusionError, physical_plan::DisplayAs}; + +use datafusion_util::{watch::WatchedTask, AdapterStream}; +use observability_deps::tracing::debug; +use tokio::sync::mpsc; +use tokio_stream::StreamExt; + +/// Implements the SchemaPivot operation described in `make_schema_pivot` +#[derive(Hash, PartialEq, Eq)] +pub struct SchemaPivotNode { + input: LogicalPlan, + schema: DFSchemaRef, + // these expressions represent what columns are "used" by this + // node (in this case all of them) -- columns that are not used + // are optimzied away by datafusion. 
+ exprs: Vec, +} + +impl SchemaPivotNode { + pub fn new(input: LogicalPlan) -> Self { + let schema = make_schema_pivot_output_schema(); + + // Form exprs that refer to all of our input columns (so that + // datafusion knows not to opimize them away) + let exprs = input + .schema() + .fields() + .iter() + .map(|field| Expr::Column(field.qualified_column())) + .collect::>(); + + Self { + input, + schema, + exprs, + } + } +} + +impl Debug for SchemaPivotNode { + /// Use explain format for the Debug format. + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.fmt_for_explain(f) + } +} + +impl UserDefinedLogicalNodeCore for SchemaPivotNode { + fn name(&self) -> &str { + "SchemaPivot" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + /// Schema for Pivot is a single string + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + self.exprs.clone() + } + + /// For example: `SchemaPivot` + fn fmt_for_explain(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name()) + } + + fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + assert_eq!(inputs.len(), 1, "SchemaPivot: input sizes inconistent"); + assert_eq!( + exprs.len(), + self.exprs.len(), + "SchemaPivot: expression sizes inconistent" + ); + Self::new(inputs[0].clone()) + } +} + +// ------ The implementation of SchemaPivot code follows ----- + +/// Create the schema describing the output +fn make_schema_pivot_output_schema() -> DFSchemaRef { + let nullable = false; + Schema::new(vec![Field::new( + "non_null_column", + DataType::Utf8, + nullable, + )]) + .to_dfschema_ref() + .unwrap() +} + +/// Physical operator that implements the SchemaPivot operation against +/// data types +pub struct SchemaPivotExec { + input: Arc, + /// Output schema + schema: SchemaRef, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, +} + +impl SchemaPivotExec { + pub fn new(input: Arc, schema: SchemaRef) -> Self { 
+ Self { + input, + schema, + metrics: ExecutionPlanMetricsSet::new(), + } + } +} + +impl Debug for SchemaPivotExec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "SchemaPivotExec") + } +} + +impl ExecutionPlan for SchemaPivotExec { + fn as_any(&self) -> &(dyn std::any::Any + 'static) { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn output_partitioning(&self) -> Partitioning { + use Partitioning::*; + match self.input.output_partitioning() { + RoundRobinBatch(num_partitions) => RoundRobinBatch(num_partitions), + // as this node transforms the output schema, whatever partitioning + // was present on the input is lost on the output + Hash(_, num_partitions) => UnknownPartitioning(num_partitions), + UnknownPartitioning(num_partitions) => UnknownPartitioning(num_partitions), + } + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn required_input_distribution(&self) -> Vec { + vec![Distribution::UnspecifiedDistribution] + } + + fn children(&self) -> Vec> { + vec![Arc::clone(&self.input)] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + match children.len() { + 1 => Ok(Arc::new(Self { + input: Arc::clone(&children[0]), + schema: Arc::clone(&self.schema), + metrics: ExecutionPlanMetricsSet::new(), + })), + _ => Err(Error::Internal( + "SchemaPivotExec wrong number of children".to_string(), + )), + } + } + + /// Execute one partition and return an iterator over RecordBatch + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + debug!(partition, "Start SchemaPivotExec::execute"); + + if self.output_partitioning().partition_count() <= partition { + return Err(Error::Internal(format!( + "SchemaPivotExec invalid partition {partition}" + ))); + } + + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + let input_schema = self.input.schema(); + let input_stream = self.input.execute(partition, context)?; + + // 
the operation performed in a separate task which is + // then sent via a channel to the output + let (tx, rx) = mpsc::channel(1); + + let fut = schema_pivot( + input_stream, + input_schema, + self.schema(), + tx.clone(), + baseline_metrics, + ); + + // A second task watches the output of the worker task and reports errors + let handle = WatchedTask::new(fut, vec![tx], "schema_pivot"); + + debug!(partition, "End SchemaPivotExec::execute"); + Ok(AdapterStream::adapt(self.schema(), rx, handle)) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + Ok(Statistics::new_unknown(&self.schema())) + } +} + +impl DisplayAs for SchemaPivotExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "SchemaPivotExec") + } + } + } +} + +// Algorithm: for each column we haven't seen a value for yet, +// check each input row; +// +// Performance Optimizations: Don't continue scaning columns +// if we have already seen a non-null value, and stop early we +// have seen values for all columns. 
+async fn schema_pivot( + mut input_stream: SendableRecordBatchStream, + input_schema: SchemaRef, + output_schema: SchemaRef, + tx: mpsc::Sender>, + baseline_metrics: BaselineMetrics, +) -> Result<(), DataFusionError> { + let input_fields = input_schema.fields(); + let num_fields = input_fields.len(); + let mut field_indexes_with_seen_values = vec![false; num_fields]; + let mut num_fields_seen_with_values = 0; + + // use a loop so that we release the mutex once we have read each input_batch + let mut keep_searching = true; + while keep_searching { + let input_batch = input_stream.next().await.transpose()?; + let timer = baseline_metrics.elapsed_compute().timer(); + + keep_searching = match input_batch { + Some(input_batch) => { + let num_rows = input_batch.num_rows(); + + for (i, seen_value) in field_indexes_with_seen_values.iter_mut().enumerate() { + // only check fields we haven't seen values for + if !*seen_value { + let column = input_batch.column(i); + + let field_has_values = !column.is_empty() && column.null_count() < num_rows; + + if field_has_values { + *seen_value = true; + num_fields_seen_with_values += 1; + } + } + } + // need to keep searching if there are still some + // fields without values + num_fields_seen_with_values < num_fields + } + // no more input + None => false, + }; + timer.done(); + } + + // now, output a string for each column in the input schema + // that we saw values for + let column_names: StringArray = field_indexes_with_seen_values + .iter() + .enumerate() + .filter_map(|(field_index, has_values)| { + if *has_values { + Some(input_fields[field_index].name()) + } else { + None + } + }) + .map(Some) + .collect(); + + let batch = RecordBatch::try_new(output_schema, vec![Arc::new(column_names)])? 
+ .record_output(&baseline_metrics); + + // and send the result back + tx.send(Ok(batch)) + .await + .map_err(|e| ArrowError::from_external_error(Box::new(e)))?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use crate::exec::stringset::{IntoStringSet, StringSetRef}; + + use super::*; + use arrow::{ + array::{Int64Array, StringArray}, + datatypes::{Field, Schema, SchemaRef}, + }; + use datafusion::physical_plan::memory::MemoryExec; + use datafusion_util::test_execute_partition; + + #[tokio::test] + async fn schema_pivot_exec_all_null() { + let case = SchemaTestCase { + input_batches: &[TestBatch { + a: &[None, None], + b: &[None, None], + }], + expected_output: &[], + }; + assert_eq!( + case.pivot().await, + case.expected_output(), + "TestCase: {case:?}" + ); + } + + #[tokio::test] + async fn schema_pivot_exec_both_non_null() { + let case = SchemaTestCase { + input_batches: &[TestBatch { + a: &[Some(1), None], + b: &[None, Some("foo")], + }], + expected_output: &["A", "B"], + }; + assert_eq!( + case.pivot().await, + case.expected_output(), + "TestCase: {case:?}" + ); + } + + #[tokio::test] + async fn schema_pivot_exec_one_non_null() { + let case = SchemaTestCase { + input_batches: &[TestBatch { + a: &[Some(1), None], + b: &[None, None], + }], + expected_output: &["A"], + }; + assert_eq!( + case.pivot().await, + case.expected_output(), + "TestCase: {case:?}" + ); + } + + #[tokio::test] + async fn schema_pivot_exec_both_non_null_two_record_batches() { + let case = SchemaTestCase { + input_batches: &[ + TestBatch { + a: &[Some(1), None], + b: &[None, None], + }, + TestBatch { + a: &[None, None], + b: &[None, Some("foo")], + }, + ], + expected_output: &["A", "B"], + }; + assert_eq!( + case.pivot().await, + case.expected_output(), + "TestCase: {case:?}" + ); + } + + #[tokio::test] + async fn schema_pivot_exec_one_non_null_in_second_record_batch() { + let case = SchemaTestCase { + input_batches: &[ + TestBatch { + a: &[None, None], + b: &[None, None], + }, + TestBatch { + a: 
&[None, Some(1), None], + b: &[None, Some("foo"), None], + }, + ], + expected_output: &["A", "B"], + }; + assert_eq!( + case.pivot().await, + case.expected_output(), + "TestCase: {case:?}" + ); + } + + #[tokio::test] + #[should_panic(expected = "SchemaPivotExec invalid partition 1")] + async fn schema_pivot_exec_bad_partition() { + // ensure passing in a bad partition generates a reasonable error + + let pivot = make_schema_pivot(SchemaTestCase::input_schema(), vec![]); + + test_execute_partition(pivot, 1).await; + } + + /// Return a StringSet extracted from the record batch + async fn reader_to_stringset(mut reader: SendableRecordBatchStream) -> StringSetRef { + let mut batches = Vec::new(); + // process the record batches one by one + while let Some(record_batch) = reader.next().await.transpose().expect("reading next batch") + { + batches.push(record_batch) + } + batches + .into_stringset() + .expect("Converted record batch reader into stringset") + } + + /// return a set for testing + fn to_stringset(strs: &[&str]) -> StringSetRef { + let stringset = strs.iter().map(|s| s.to_string()).collect(); + StringSetRef::new(stringset) + } + + /// Create a schema pivot node with a single input + fn make_schema_pivot( + input_schema: SchemaRef, + data: Vec, + ) -> Arc { + let input = make_memory_exec(input_schema, data); + let output_schema = Arc::new(make_schema_pivot_output_schema().as_ref().clone().into()); + Arc::new(SchemaPivotExec::new(input, output_schema)) + } + + /// Create an ExecutionPlan that produces `data` record batches. 
+ fn make_memory_exec(schema: SchemaRef, data: Vec) -> Arc { + let partitions = vec![data]; // single partition + let projection = None; + + let memory_exec = + MemoryExec::try_new(&partitions, schema, projection).expect("creating memory exec"); + + Arc::new(memory_exec) + } + + fn to_string_array(strs: &[Option<&str>]) -> Arc { + let arr: StringArray = strs.iter().collect(); + Arc::new(arr) + } + + // Input schema is (A INT, B STRING) + #[derive(Debug)] + struct TestBatch<'a> { + a: &'a [Option], + b: &'a [Option<&'a str>], + } + + // Input schema is (A INT, B STRING) + #[derive(Debug)] + struct SchemaTestCase<'a> { + // Input record batches, slices of slices (a,b) + input_batches: &'a [TestBatch<'a>], + expected_output: &'a [&'a str], + } + + impl SchemaTestCase<'_> { + fn input_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("A", DataType::Int64, true), + Field::new("B", DataType::Utf8, true), + ])) + } + + /// return expected output, as StringSet + fn expected_output(&self) -> StringSetRef { + to_stringset(self.expected_output) + } + + /// run the input batches through a schema pivot and return the results + /// as a StringSetRef + async fn pivot(&self) -> StringSetRef { + let schema = Self::input_schema(); + + // prepare input + let input_batches = self + .input_batches + .iter() + .map(|test_batch| { + let a_vec = test_batch.a.to_vec(); + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(a_vec)), + to_string_array(test_batch.b), + ], + ) + .expect("Creating new record batch") + }) + .collect::>(); + + let pivot = make_schema_pivot(schema, input_batches); + + let results = test_execute_partition(pivot, 0).await; + + reader_to_stringset(results).await + } + } +} diff --git a/iox_query/src/exec/seriesset.rs b/iox_query/src/exec/seriesset.rs new file mode 100644 index 0000000..7ff393d --- /dev/null +++ b/iox_query/src/exec/seriesset.rs @@ -0,0 +1,89 @@ +//! 
This module contains the definition of a "SeriesSet" a plan that when run +//! produces rows that can be logically divided into "Series" +//! +//! Specifically, a SeriesSet wraps a "table", and each table is +//! sorted on a set of "tag" columns, meaning the data the series +//! series will be contiguous. +//! +//! For example, the output columns of such a plan would be: +//! (tag col0) (tag col1) ... (tag colN) (field val1) (field val2) ... (field +//! valN) .. (timestamps) +//! +//! Note that the data will come out ordered by the tag keys (ORDER BY +//! (tag col0) (tag col1) ... (tag colN)) +//! +//! NOTE: The InfluxDB classic storage engine not only returns +//! series sorted by the tag values, but the order of the tag columns +//! (and thus the actual sort order) is also lexographically +//! sorted. So for example, if you have `region`, `host`, and +//! `service` as tags, the columns would be ordered `host`, `region`, +//! and `service` as well. + +pub mod converter; +pub mod series; + +use arrow::{self, record_batch::RecordBatch}; + +use std::sync::Arc; + +use super::field::FieldIndexes; + +#[derive(Debug)] +/// Information to map a slice of rows in a [`RecordBatch`] sorted by +/// tags and timestamps to several timeseries that share the same +/// tag keys and timestamps. 
+/// +/// The information in a [`SeriesSet`] can be used to "unpivot" a +/// [`RecordBatch`] into one or more Time Series as [`series::Series`] +/// +/// For example, given the following set of rows from a [`RecordBatch`] +/// which must be sorted by `(TagA, TagB, time)`: +// +/// TagA | TagB | Field1 | Field2 | time +/// -----+------+--------+--------+------- +/// a | b | 1 | 10 | 100 +/// a | b | 2 | 20 | 200 +/// a | b | 3 | 30 | 300 +/// a | x | 11 | | 100 +/// a | x | 12 | | 200 +/// +/// Would be represented as +/// * `SeriesSet` 1: For {TagA='a', TagB='b'} +/// * `SeriesSet` 2: For {TagA='a', TagB='x'} +/// +/// `SeriesSet` 1 would produce 2 series (one for each field): +/// +/// {_field=Field1, TagA=a, TagB=b} timestamps = {100, 200, 300} values = {1, 2, 3} +/// {_field=Field2, TagA=a, TagB=b} timestamps = {100, 200, 300} values = {100, 200, 300} +/// +/// `SeriesSet` 2 would produce a single series for `Field1` (no +/// series is created for `Field2` because there are no values for +/// `Field2` where TagA=a, and TagB=x) +/// +/// {_field=Field1, TagA=a, TagB=x} timestamps = {100, 200} values = {11, 12} +/// +/// NB: The heavy use of `Arc` is to avoid many duplicated Strings given +/// the the fact that many SeriesSets share the same tag keys and +/// table name. +pub struct SeriesSet { + /// The table name this series came from + pub table_name: Arc, + + /// key = value pairs that define this series + pub tags: Vec<(Arc, Arc)>, + + /// the column index of each "field" of the time series. For + /// example, if there are two field indexes then this series set + /// would result in two distinct series being sent back, one for + /// each field. 
+ pub field_indexes: FieldIndexes, + + // The row in the record batch where the data starts (inclusive) + pub start_row: usize, + + // The number of rows in the record batch that the data goes to + pub num_rows: usize, + + // The underlying record batch data + pub batch: RecordBatch, +} diff --git a/iox_query/src/exec/seriesset/converter.rs b/iox_query/src/exec/seriesset/converter.rs new file mode 100644 index 0000000..81e8384 --- /dev/null +++ b/iox_query/src/exec/seriesset/converter.rs @@ -0,0 +1,1746 @@ +//! This module contains code that "unpivots" annotated +//! [`RecordBatch`]es to [`Series`] and [`Group`]s for output by the +//! storage gRPC interface + +use arrow::{ + self, + array::{downcast_array, Array, BooleanArray, DictionaryArray, StringArray}, + compute, + datatypes::{DataType, Int32Type, SchemaRef}, + record_batch::RecordBatch, +}; +use datafusion::{ + error::DataFusionError, + execution::memory_pool::{proxy::VecAllocExt, MemoryConsumer, MemoryPool, MemoryReservation}, + physical_plan::SendableRecordBatchStream, +}; + +use futures::{ready, Stream, StreamExt}; +use predicate::rpc_predicate::{GROUP_KEY_SPECIAL_START, GROUP_KEY_SPECIAL_STOP}; +use snafu::{OptionExt, Snafu}; +use std::{ + collections::VecDeque, + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use crate::exec::{ + field::{self, FieldColumns, FieldIndexes}, + seriesset::series::Group, +}; + +use super::{ + series::{Either, Series}, + SeriesSet, +}; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("Internal field error while converting series set: {}", source))] + InternalField { source: field::Error }, + + #[snafu(display("Internal error finding grouping colum: {}", column_name))] + FindingGroupColumn { column_name: String }, +} + +pub type Result = std::result::Result; + +// Handles converting record batches into SeriesSets +#[derive(Debug, Default, Copy, Clone)] +pub struct SeriesSetConverter {} + +impl SeriesSetConverter { + /// Convert the 
results from running a DataFusion plan into the + /// appropriate SeriesSetItems. + /// + /// The results must be in the logical format described in this + /// module's documentation (i.e. ordered by tag keys) + /// + /// table_name: The name of the table + /// + /// tag_columns: The names of the columns that define tags + /// + /// field_columns: The names of the columns which are "fields" + /// + /// it: record batch iterator that produces data in the desired order + pub async fn convert( + &mut self, + table_name: Arc, + tag_columns: Arc>>, + field_columns: FieldColumns, + it: SendableRecordBatchStream, + ) -> Result>, DataFusionError> { + assert_eq!( + tag_columns.as_ref(), + &{ + let mut tmp = tag_columns.as_ref().clone(); + tmp.sort(); + tmp + }, + "Tag column sorted", + ); + + let schema = it.schema(); + + let tag_indexes = FieldIndexes::names_to_indexes(&schema, &tag_columns).map_err(|e| { + DataFusionError::Context( + "Internal field error while converting series set".to_string(), + Box::new(DataFusionError::External(Box::new(e))), + ) + })?; + let field_indexes = + FieldIndexes::from_field_columns(&schema, &field_columns).map_err(|e| { + DataFusionError::Context( + "Internal field error while converting series set".to_string(), + Box::new(DataFusionError::External(Box::new(e))), + ) + })?; + + Ok(SeriesSetConverterStream { + result_buffer: VecDeque::default(), + open_batches: Vec::default(), + need_new_batch: true, + we_finished: false, + schema, + it: Some(it), + tag_indexes, + field_indexes, + table_name, + tag_columns, + }) + } + + /// Returns the row indexes in `batch` where all of the values in the `tag_indexes` columns + /// take on a new value. 
+ /// + /// For example: + /// + /// ```text + /// tags A, B + /// ``` + /// + /// If the input is: + /// + /// A | B | C + /// - | - | - + /// 1 | 2 | x + /// 1 | 2 | y + /// 2 | 2 | z + /// 3 | 3 | q + /// 3 | 3 | r + /// + /// Then this function will return `[3, 4]`: + /// + /// - The row at index 3 has values for A and B (2,2) different than the previous row (1,2). + /// - Similarly the row at index 4 has values (3,3) which are different than (2,2). + /// - However, the row at index 5 has the same values (3,3) so is NOT a transition point + fn compute_changepoints(batch: &RecordBatch, tag_indexes: &[usize]) -> Vec { + let tag_transitions = tag_indexes + .iter() + .map(|&col| Self::compute_transitions(batch, col)) + .collect::>(); + + // no tag columns, emit a single tagset + if tag_transitions.is_empty() { + vec![] + } else { + // OR bitsets together to to find all rows where the + // keyset (values of the tag keys) changes + let mut tag_transitions_it = tag_transitions.into_iter(); + let init = tag_transitions_it.next().expect("not empty"); + let intersections = + tag_transitions_it.fold(init, |a, b| compute::or(&a, &b).expect("or operation")); + + intersections + .iter() + .enumerate() + .filter(|(_idx, mask)| mask.unwrap_or(true)) + .map(|(idx, _mask)| idx) + .collect() + } + } + + /// returns a bitset with all row indexes where the value of the + /// batch `col_idx` changes. 
Does not include row 0, always includes
    /// the last row, `batch.num_rows() - 1`
    ///
    /// Note: This may return false positives in the presence of dictionaries
    /// containing duplicates
    fn compute_transitions(batch: &RecordBatch, col_idx: usize) -> BooleanArray {
        let num_rows = batch.num_rows();

        if num_rows == 0 {
            return BooleanArray::builder(0).finish();
        }

        let col = batch.column(col_idx);

        // Row 0 is never a transition; row i (i > 0) is a transition when
        // col[i] != col[i - 1], i.e. compare the column against itself
        // shifted by one row.
        let head = {
            let mut b = BooleanArray::builder(1);
            b.append_value(false);
            b.finish()
        };
        let shifted_neq = arrow::compute::kernels::cmp::neq(
            &col.slice(0, col.len() - 1),
            &col.slice(1, col.len() - 1),
        )
        .expect("cmp");

        let arr = compute::concat(&[&head, &shifted_neq]).expect("concat");
        downcast_array(&arr)
    }

    /// Creates (column_name, column_value) pairs for each column
    /// named in `tag_column_names` at the corresponding index in
    /// `tag_indexes`
    fn get_tag_keys(
        batch: &RecordBatch,
        row: usize,
        tag_column_names: &[Arc<str>],
        tag_indexes: &[usize],
    ) -> Vec<(Arc<str>, Arc<str>)> {
        assert_eq!(tag_column_names.len(), tag_indexes.len());

        let mut out = tag_column_names
            .iter()
            .zip(tag_indexes)
            .filter_map(|(column_name, &column_index)| {
                let col = batch.column(column_index);
                // NULL tag values are skipped entirely (no pair emitted)
                let tag_value = match col.data_type() {
                    DataType::Utf8 => {
                        let col = col.as_any().downcast_ref::<StringArray>().unwrap();
                        col.is_valid(row).then(|| col.value(row).to_string())
                    }
                    DataType::Dictionary(key, value)
                        if key.as_ref() == &DataType::Int32
                            && value.as_ref() == &DataType::Utf8 =>
                    {
                        let col = col
                            .as_any()
                            .downcast_ref::<DictionaryArray<Int32Type>>()
                            .expect("Casting column");

                        if col.is_valid(row) {
                            let key = col.keys().value(row);
                            let value = col
                                .values()
                                .as_any()
                                .downcast_ref::<StringArray>()
                                .unwrap()
                                .value(key as _)
                                .to_string();
                            Some(value)
                        } else {
                            None
                        }
                    }
                    _ => unimplemented!(
                        "Series get_tag_keys not supported for type {:?} in column {:?}",
                        col.data_type(),
                        batch.schema().fields()[column_index]
                    ),
                };

                tag_value.map(|tag_value| (Arc::clone(column_name), Arc::from(tag_value.as_str())))
            })
            .collect::<Vec<_>>();

        out.shrink_to_fit();
        out
    }
}

struct SeriesSetConverterStream {
    /// [`SeriesSet`]s that are ready to be emitted by this stream.
    ///
    /// These results must always be emitted before doing any additional work.
    result_buffer: VecDeque<SeriesSet>,

    /// Batches of data that have NO change point, i.e. they all belong to the same output set. However we have not yet
    /// found the next change point (or the end of the stream) so we need to keep them.
    ///
    /// We keep a list of batches instead of a giant concatenated batch to avoid `O(n^2)` complexity due to repeated mem-copies.
    open_batches: Vec<RecordBatch>,

    /// If `true`, we need to pull a new batch from `it`.
    need_new_batch: bool,

    /// We (i.e. [`SeriesSetConverterStream`]) completed our work. However there might be data available in
    /// [`result_buffer`](Self::result_buffer) which must be drained before returning `Ready(None)`.
    we_finished: bool,

    /// The schema of the input data.
    schema: SchemaRef,

    /// Indexes (within [`schema`](Self::schema)) of the tag columns.
    tag_indexes: Vec<usize>,

    /// Indexes (within [`schema`](Self::schema)) of the field columns.
    field_indexes: FieldIndexes,

    /// Name of the table we're operating on.
    ///
    /// This is required because this is part of the output [`SeriesSet`]s.
    table_name: Arc<str>,

    /// Name of the tag columns.
    ///
    /// This is kept in addition to [`tag_indexes`](Self::tag_indexes) because it is part of the output [`SeriesSet`]s.
    tag_columns: Arc<Vec<Arc<str>>>,

    /// Input data stream.
    ///
    /// This may be `None` when the stream was fully drained. We need to remember that fact so we don't poll a
    /// finished stream (which may panic).
+ it: Option, +} + +impl Stream for SeriesSetConverterStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = &mut *self; + + loop { + // drain results + if let Some(sset) = this.result_buffer.pop_front() { + return Poll::Ready(Some(Ok(sset))); + } + + // early exit + if this.we_finished { + return Poll::Ready(None); + } + + // do we need more input data? + if this.need_new_batch { + loop { + match ready!(this + .it + .as_mut() + .expect("need new input but input stream is already drained") + .poll_next_unpin(cx)) + { + Some(Err(e)) => { + return Poll::Ready(Some(Err(e))); + } + Some(Ok(batch)) => { + // skip empty batches (simplifies our code further down below because we can always assume that + // there's at least one row in the batch) + if batch.num_rows() == 0 { + continue; + } + + this.open_batches.push(batch) + } + None => { + this.it = None; + } + } + break; + } + + this.need_new_batch = false; + } + + // do we only have a single batch or do we "overflow" from the last batch? + let (batch_for_changepoints, extra_first_row) = match this.open_batches.len() { + 0 => { + assert!( + this.it.is_none(), + "We have no open batches left, so the input stream should be finished", + ); + this.we_finished = true; + return Poll::Ready(None); + } + 1 => ( + this.open_batches.last().expect("checked length").clone(), + false, + ), + _ => { + // `open_batches` contains at least two batches. The last one was added just from the input stream. + // The prev. one was the end of the "open" interval and all before that belong to the same output + // set (because otherwise we would have flushed them earlier). 
+ let batch_last = &this.open_batches[this.open_batches.len() - 2]; + let batch_current = &this.open_batches[this.open_batches.len() - 1]; + assert!(batch_last.num_rows() > 0); + + let batch = match compute::concat_batches( + &this.schema, + &[ + batch_last.slice(batch_last.num_rows() - 1, 1), + batch_current.clone(), + ], + ) { + Ok(batch) => batch, + Err(e) => { + // internal state is broken, end this stream + this.we_finished = true; + return Poll::Ready(Some(Err(DataFusionError::ArrowError(e, None)))); + } + }; + + (batch, true) + } + }; + + // compute changepoints + let mut changepoints = SeriesSetConverter::compute_changepoints( + &batch_for_changepoints, + &this.tag_indexes, + ); + if this.it.is_none() { + // need to finish last SeriesSet + changepoints.push(batch_for_changepoints.num_rows()); + } + let prev_sizes = this.open_batches[..(this.open_batches.len() - 1)] + .iter() + .map(|b| b.num_rows()) + .sum::(); + let cp_delta = if extra_first_row { + prev_sizes + .checked_sub(1) + .expect("at least one non-empty prev. batch") + } else { + prev_sizes + }; + let changepoints = changepoints + .into_iter() + .map(|x| x + cp_delta) + .collect::>(); + + // already change to "needs data" before we start emission + this.need_new_batch = true; + if this.it.is_none() { + this.we_finished = true; + } + + if !changepoints.is_empty() { + // `batch_for_changepoints` only contains the last batch and the last row of the prev. one. However we + // need to flush ALL rows in `open_batches` (and keep the ones after the last changepoint as a new open + // batch). So concat again. 
+ let batch_for_flush = + match compute::concat_batches(&this.schema, &this.open_batches) { + Ok(batch) => batch, + Err(e) => { + // internal state is broken, end this stream + this.we_finished = true; + return Poll::Ready(Some(Err(DataFusionError::ArrowError(e, None)))); + } + }; + + let last_cp = *changepoints.last().expect("checked length"); + if last_cp == batch_for_flush.num_rows() { + // fully drained open batches + // This can ONLY happen when the input stream finished because `comput_changepoint` never returns + // the last row as changepoint (so we must have manually added that above). + assert!( + this.it.is_none(), + "Fully flushed all open batches but the input stream still has data?!" + ); + this.open_batches.drain(..); + } else { + // need to keep the open bit + // do NOT use `batch` here because it contains data for all open batches, we just need the last one + // (`slice` is zero-copy) + let offset = last_cp.checked_sub(prev_sizes).expect("underflow"); + let last_batch = this.open_batches.last().expect("at least one batch"); + let last_batch = last_batch.slice( + offset, + last_batch + .num_rows() + .checked_sub(offset) + .expect("underflow"), + ); + this.open_batches.drain(..); + this.open_batches.push(last_batch); + } + + // emit each series + let mut start_row: usize = 0; + assert!(this.result_buffer.is_empty()); + this.result_buffer = changepoints + .into_iter() + .map(|end_row| { + let series_set = SeriesSet { + table_name: Arc::clone(&this.table_name), + tags: SeriesSetConverter::get_tag_keys( + &batch_for_flush, + start_row, + &this.tag_columns, + &this.tag_indexes, + ), + field_indexes: this.field_indexes.clone(), + start_row, + num_rows: (end_row - start_row), + // batch clones are super cheap (in contrast to `slice` which has a way higher overhead!) 
+ batch: batch_for_flush.clone(), + }; + + start_row = end_row; + series_set + }) + .collect(); + } + } + } +} + +/// Reorders and groups a sequence of Series is grouped correctly +#[derive(Debug)] +pub struct GroupGenerator { + group_columns: Vec>, + memory_pool: Arc, + collector_buffered_size_max: usize, +} + +impl GroupGenerator { + pub fn new(group_columns: Vec>, memory_pool: Arc) -> Self { + Self::new_with_buffered_size_max( + group_columns, + memory_pool, + Collector::<()>::DEFAULT_ALLOCATION_BUFFER_SIZE, + ) + } + + fn new_with_buffered_size_max( + group_columns: Vec>, + memory_pool: Arc, + collector_buffered_size_max: usize, + ) -> Self { + Self { + group_columns, + memory_pool, + collector_buffered_size_max, + } + } + + /// groups the set of `series` into SeriesOrGroups + /// + /// TODO: make this truly stream-based, see . + pub async fn group( + self, + series: S, + ) -> Result>, DataFusionError> + where + S: Stream> + Send, + { + let series = Box::pin(series); + let mut series = Collector::new( + series, + self.group_columns, + self.memory_pool, + self.collector_buffered_size_max, + ) + .await?; + + // Potential optimization is to skip this sort if we are + // grouping by a prefix of the tags for a single measurement + // + // Another potential optimization is if we are only grouping on + // tag columns is to change the the actual plan output using + // DataFusion to sort the data in the required group (likely + // only possible with a single table) + + // Resort the data according to group key values + series.sort(); + + // now find the groups boundaries and emit the output + let mut last_partition_key_vals: Option>> = None; + + // Note that if there are no group columns, we still need to + // sort by the tag keys, so that the output is sorted by tag + // keys, and thus we can't bail out early here + // + // Interesting, it isn't clear flux requires this ordering, but + // it is what TSM does so we preserve the behavior + let mut output = vec![]; + + // 
TODO make this more functional (issue is that sometimes the + // loop inserts one item into `output` and sometimes it inserts 2) + for SortableSeries { + series, + tag_vals, + num_partition_keys, + } in series.into_iter() + { + // keep only the values that form the group + let mut partition_key_vals = tag_vals; + partition_key_vals.truncate(num_partition_keys); + + // figure out if we are in a new group (partition key values have changed) + let need_group_start = match &last_partition_key_vals { + None => true, + Some(last_partition_key_vals) => &partition_key_vals != last_partition_key_vals, + }; + + if need_group_start { + last_partition_key_vals = Some(partition_key_vals.clone()); + + let tag_keys = series.tags.iter().map(|tag| Arc::clone(&tag.key)).collect(); + + let group = Group { + tag_keys, + partition_key_vals, + }; + + output.push(group.into()); + } + + output.push(series.into()) + } + + Ok(futures::stream::iter(output).map(Ok)) + } +} + +#[derive(Debug)] +/// Wrapper around a Series that has the values of the group_by columns extracted +struct SortableSeries { + series: Series, + + /// All the tag values, reordered so that the group_columns are first + tag_vals: Vec>, + + /// How many of the first N tag_values are used for the partition key + num_partition_keys: usize, +} + +impl PartialEq for SortableSeries { + fn eq(&self, other: &Self) -> bool { + self.tag_vals.eq(&other.tag_vals) + } +} + +impl Eq for SortableSeries {} + +impl PartialOrd for SortableSeries { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for SortableSeries { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.tag_vals.cmp(&other.tag_vals) + } +} + +impl SortableSeries { + fn try_new(series: Series, group_columns: &[Arc]) -> Result { + // Compute the order of new tag values + let tags = &series.tags; + + // tag_used_set[i] is true if we have used the value in tag_columns[i] + let mut tag_used_set = vec![false; tags.len()]; + + // 
put the group columns first + // + // Note that this is an O(N^2) algorithm. We are assuming the + // number of tag columns is reasonably small + let mut tag_vals: Vec<_> = group_columns + .iter() + .map(|col| { + tags.iter() + .enumerate() + // Searching for columns linearly is likely to be pretty slow.... + .find(|(_i, tag)| tag.key == *col) + .map(|(i, tag)| { + assert!(!tag_used_set[i], "repeated group column"); + tag_used_set[i] = true; + Arc::clone(&tag.value) + }) + .or_else(|| { + // treat these specially and use value "" to mirror what TSM does + // see https://github.com/influxdata/influxdb_iox/issues/2693#issuecomment-947695442 + // for more details + if col.as_ref() == GROUP_KEY_SPECIAL_START + || col.as_ref() == GROUP_KEY_SPECIAL_STOP + { + Some(Arc::from("")) + } else { + None + } + }) + .context(FindingGroupColumnSnafu { + column_name: col.as_ref(), + }) + }) + .collect::>>()?; + + // Fill in all remaining tags + tag_vals.extend(tags.iter().enumerate().filter_map(|(i, tag)| { + let use_tag = !tag_used_set[i]; + use_tag.then(|| Arc::clone(&tag.value)) + })); + + // safe memory + tag_vals.shrink_to_fit(); + + Ok(Self { + series, + tag_vals, + num_partition_keys: group_columns.len(), + }) + } + + /// Memory usage in bytes, including `self`. + fn size(&self) -> usize { + std::mem::size_of_val(self) + self.series.size() - std::mem::size_of_val(&self.series) + + (std::mem::size_of::>() * self.tag_vals.capacity()) + + self.tag_vals.iter().map(|s| s.len()).sum::() + } +} + +/// [`Future`] that collects [`Series`] objects into a [`SortableSeries`] vector while registering/checking memory +/// allocations with a [`MemoryPool`]. +/// +/// This avoids unbounded memory growth when merging multiple `Series` in memory +struct Collector { + /// The inner stream was fully drained. + inner_done: bool, + + /// This very future finished. + outer_done: bool, + + /// Inner stream. + inner: S, + + /// Group columns. 
+ /// + /// These are required for [`SortableSeries::try_new`]. + group_columns: Vec>, + + /// Already collected objects. + collected: Vec, + + /// Buffered but not-yet-registered allocated size. + /// + /// We use an additional buffer here because in contrast to the normal DataFusion processing, the input stream is + /// NOT batched and we want to avoid costly memory allocations checks with the [`MemoryPool`] for every single element. + buffered_size: usize, + + /// Maximum [buffered size](Self::buffered_size). Decreasing this + /// value causes allocations to be reported to the [`MemoryPool`] + /// more frequently. + buffered_size_max: usize, + + /// Our memory reservation. + mem_reservation: MemoryReservation, +} + +impl Collector { + /// Default maximum [buffered size](Self::buffered_size) before updating [`MemoryPool`] reservation + const DEFAULT_ALLOCATION_BUFFER_SIZE: usize = 1024 * 1024; +} + +impl Collector +where + S: Stream> + Send + Unpin, +{ + fn new( + inner: S, + group_columns: Vec>, + memory_pool: Arc, + buffered_size_max: usize, + ) -> Self { + let mem_reservation = MemoryConsumer::new("SeriesSet Collector").register(&memory_pool); + + Self { + inner_done: false, + outer_done: false, + inner, + group_columns, + collected: Vec::with_capacity(0), + buffered_size: 0, + buffered_size_max, + mem_reservation, + } + } + + /// Registers all `self.buffered_size` with the MemoryPool, + /// resetting self.buffered_size to zero. Returns an error if new + /// memory can not be allocated from the pool. 
+ fn alloc(&mut self) -> Result<(), DataFusionError> { + let bytes = std::mem::take(&mut self.buffered_size); + self.mem_reservation.try_grow(bytes) + } +} + +impl Future for Collector +where + S: Stream> + Send + Unpin, +{ + type Output = Result, DataFusionError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = &mut *self; + + loop { + assert!(!this.outer_done); + // if the underlying stream is drained and the allocation future is ready (see above), we can finalize this future + if this.inner_done { + this.outer_done = true; + return Poll::Ready(Ok(std::mem::take(&mut this.collected))); + } + + match ready!(this.inner.poll_next_unpin(cx)) { + Some(Ok(series)) => match SortableSeries::try_new(series, &this.group_columns) { + Ok(series) => { + // Note: the size of `SortableSeries` itself is already included in the vector allocation + this.buffered_size += series.size() - std::mem::size_of_val(&series); + this.collected + .push_accounted(series, &mut this.buffered_size); + + // should we clear our allocation buffer? + if this.buffered_size > this.buffered_size_max { + if let Err(e) = this.alloc() { + return Poll::Ready(Err(e)); + } + continue; + } + } + Err(e) => { + // poison this future + this.outer_done = true; + return Poll::Ready(Err(DataFusionError::External(Box::new(e)))); + } + }, + Some(Err(e)) => { + // poison this future + this.outer_done = true; + return Poll::Ready(Err(e)); + } + None => { + // underlying stream drained. 
now register the final allocation and then we're done + this.inner_done = true; + if this.buffered_size > 0 { + if let Err(e) = this.alloc() { + return Poll::Ready(Err(e)); + } + } + continue; + } + } + } + } +} + +#[cfg(test)] +mod tests { + use arrow::{ + array::{ArrayRef, Float64Array, Int64Array, TimestampNanosecondArray}, + csv, + datatypes::DataType, + datatypes::Field, + datatypes::{Schema, SchemaRef}, + record_batch::RecordBatch, + }; + use arrow_util::assert_batches_eq; + use assert_matches::assert_matches; + use datafusion::execution::memory_pool::GreedyMemoryPool; + use datafusion_util::{stream_from_batch, stream_from_batches, stream_from_schema}; + use futures::TryStreamExt; + use itertools::Itertools; + use test_helpers::str_vec_to_arc_vec; + + use crate::exec::seriesset::series::{Batch, Data, Tag}; + + use super::*; + + #[tokio::test] + async fn test_convert_empty() { + let schema = test_schema(); + let empty_iterator = stream_from_schema(schema); + + let table_name = "foo"; + let tag_columns = []; + let field_columns = []; + + let results = convert(table_name, &tag_columns, &field_columns, empty_iterator).await; + assert_eq!(results.len(), 0); + } + + fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("tag_a", DataType::Utf8, true), + Field::new("tag_b", DataType::Utf8, true), + Field::new("float_field", DataType::Float64, true), + Field::new("int_field", DataType::Int64, true), + Field::new("time", DataType::Int64, false), + ])) + } + + #[tokio::test] + async fn test_convert_single_series_no_tags() { + // single series + let schema = test_schema(); + let inputs = parse_to_iterators(schema, &["one,ten,10.0,1,1000", "one,ten,10.1,2,2000"]); + for (i, input) in inputs.into_iter().enumerate() { + println!("Stream {i}"); + + let table_name = "foo"; + let tag_columns = []; + let field_columns = ["float_field"]; + let results = convert(table_name, &tag_columns, &field_columns, input).await; + + assert_eq!(results.len(), 1); + + 
assert_series_set( + &results[0], + "foo", + [], + FieldIndexes::from_timestamp_and_value_indexes(4, &[2]), + [ + "+-------+-------+-------------+-----------+------+", + "| tag_a | tag_b | float_field | int_field | time |", + "+-------+-------+-------------+-----------+------+", + "| one | ten | 10.0 | 1 | 1000 |", + "| one | ten | 10.1 | 2 | 2000 |", + "+-------+-------+-------------+-----------+------+", + ], + ); + } + } + + #[tokio::test] + async fn test_convert_single_series_no_tags_nulls() { + // single series + let schema = test_schema(); + + let inputs = parse_to_iterators(schema, &["one,ten,10.0,,1000", "one,ten,10.1,,2000"]); + + // send no values in the int_field colum + for (i, input) in inputs.into_iter().enumerate() { + println!("Stream {i}"); + + let table_name = "foo"; + let tag_columns = []; + let field_columns = ["float_field"]; + let results = convert(table_name, &tag_columns, &field_columns, input).await; + + assert_eq!(results.len(), 1); + + assert_series_set( + &results[0], + "foo", + [], + FieldIndexes::from_timestamp_and_value_indexes(4, &[2]), + [ + "+-------+-------+-------------+-----------+------+", + "| tag_a | tag_b | float_field | int_field | time |", + "+-------+-------+-------------+-----------+------+", + "| one | ten | 10.0 | | 1000 |", + "| one | ten | 10.1 | | 2000 |", + "+-------+-------+-------------+-----------+------+", + ], + ); + } + } + + #[tokio::test] + async fn test_convert_single_series_one_tag() { + // single series + let schema = test_schema(); + let inputs = parse_to_iterators(schema, &["one,ten,10.0,1,1000", "one,ten,10.1,2,2000"]); + + for (i, input) in inputs.into_iter().enumerate() { + println!("Stream {i}"); + + // test with one tag column, one series + let table_name = "bar"; + let tag_columns = ["tag_a"]; + let field_columns = ["float_field"]; + let results = convert(table_name, &tag_columns, &field_columns, input).await; + + assert_eq!(results.len(), 1); + + assert_series_set( + &results[0], + "bar", + 
[("tag_a", "one")], + FieldIndexes::from_timestamp_and_value_indexes(4, &[2]), + [ + "+-------+-------+-------------+-----------+------+", + "| tag_a | tag_b | float_field | int_field | time |", + "+-------+-------+-------------+-----------+------+", + "| one | ten | 10.0 | 1 | 1000 |", + "| one | ten | 10.1 | 2 | 2000 |", + "+-------+-------+-------------+-----------+------+", + ], + ); + } + } + + #[tokio::test] + async fn test_convert_single_series_one_tag_more_rows() { + // single series + let schema = test_schema(); + let inputs = parse_to_iterators( + schema, + &[ + "one,ten,10.0,1,1000", + "one,ten,10.1,2,2000", + "one,ten,10.2,3,3000", + ], + ); + + for (i, input) in inputs.into_iter().enumerate() { + println!("Stream {i}"); + + // test with one tag column, one series + let table_name = "bar"; + let tag_columns = ["tag_a"]; + let field_columns = ["float_field"]; + let results = convert(table_name, &tag_columns, &field_columns, input).await; + + assert_eq!(results.len(), 1); + + assert_series_set( + &results[0], + "bar", + [("tag_a", "one")], + FieldIndexes::from_timestamp_and_value_indexes(4, &[2]), + [ + "+-------+-------+-------------+-----------+------+", + "| tag_a | tag_b | float_field | int_field | time |", + "+-------+-------+-------------+-----------+------+", + "| one | ten | 10.0 | 1 | 1000 |", + "| one | ten | 10.1 | 2 | 2000 |", + "| one | ten | 10.2 | 3 | 3000 |", + "+-------+-------+-------------+-----------+------+", + ], + ); + } + } + + #[tokio::test] + async fn test_convert_one_tag_multi_series() { + let schema = test_schema(); + + let inputs = parse_to_iterators( + schema, + &[ + "one,ten,10.0,1,1000", + "one,ten,10.1,2,2000", + "one,eleven,10.1,3,3000", + "two,eleven,10.2,4,4000", + "two,eleven,10.3,5,5000", + ], + ); + + for (i, input) in inputs.into_iter().enumerate() { + println!("Stream {i}"); + + let table_name = "foo"; + let tag_columns = ["tag_a"]; + let field_columns = ["int_field"]; + let results = convert(table_name, 
&tag_columns, &field_columns, input).await; + + assert_eq!(results.len(), 2); + + assert_series_set( + &results[0], + "foo", + [("tag_a", "one")], + FieldIndexes::from_timestamp_and_value_indexes(4, &[3]), + [ + "+-------+--------+-------------+-----------+------+", + "| tag_a | tag_b | float_field | int_field | time |", + "+-------+--------+-------------+-----------+------+", + "| one | ten | 10.0 | 1 | 1000 |", + "| one | ten | 10.1 | 2 | 2000 |", + "| one | eleven | 10.1 | 3 | 3000 |", + "+-------+--------+-------------+-----------+------+", + ], + ); + assert_series_set( + &results[1], + "foo", + [("tag_a", "two")], + FieldIndexes::from_timestamp_and_value_indexes(4, &[3]), + [ + "+-------+--------+-------------+-----------+------+", + "| tag_a | tag_b | float_field | int_field | time |", + "+-------+--------+-------------+-----------+------+", + "| two | eleven | 10.2 | 4 | 4000 |", + "| two | eleven | 10.3 | 5 | 5000 |", + "+-------+--------+-------------+-----------+------+", + ], + ); + } + } + + // two tag columns, three series + #[tokio::test] + async fn test_convert_two_tag_multi_series() { + let schema = test_schema(); + + let inputs = parse_to_iterators( + schema, + &[ + "one,ten,10.0,1,1000", + "one,ten,10.1,2,2000", + "one,eleven,10.1,3,3000", + "two,eleven,10.2,4,4000", + "two,eleven,10.3,5,5000", + ], + ); + + for (i, input) in inputs.into_iter().enumerate() { + println!("Stream {i}"); + + let table_name = "foo"; + let tag_columns = ["tag_a", "tag_b"]; + let field_columns = ["int_field"]; + let results = convert(table_name, &tag_columns, &field_columns, input).await; + + assert_eq!(results.len(), 3); + + assert_series_set( + &results[0], + "foo", + [("tag_a", "one"), ("tag_b", "ten")], + FieldIndexes::from_timestamp_and_value_indexes(4, &[3]), + [ + "+-------+-------+-------------+-----------+------+", + "| tag_a | tag_b | float_field | int_field | time |", + "+-------+-------+-------------+-----------+------+", + "| one | ten | 10.0 | 1 | 1000 |", 
+ "| one | ten | 10.1 | 2 | 2000 |", + "+-------+-------+-------------+-----------+------+", + ], + ); + assert_series_set( + &results[1], + "foo", + [("tag_a", "one"), ("tag_b", "eleven")], + FieldIndexes::from_timestamp_and_value_indexes(4, &[3]), + [ + "+-------+--------+-------------+-----------+------+", + "| tag_a | tag_b | float_field | int_field | time |", + "+-------+--------+-------------+-----------+------+", + "| one | eleven | 10.1 | 3 | 3000 |", + "+-------+--------+-------------+-----------+------+", + ], + ); + assert_series_set( + &results[2], + "foo", + [("tag_a", "two"), ("tag_b", "eleven")], + FieldIndexes::from_timestamp_and_value_indexes(4, &[3]), + [ + "+-------+--------+-------------+-----------+------+", + "| tag_a | tag_b | float_field | int_field | time |", + "+-------+--------+-------------+-----------+------+", + "| two | eleven | 10.2 | 4 | 4000 |", + "| two | eleven | 10.3 | 5 | 5000 |", + "+-------+--------+-------------+-----------+------+", + ], + ); + } + } + + #[tokio::test] + async fn test_convert_two_tag_with_null_multi_series() { + let tag_a = StringArray::from(vec!["one", "one", "one"]); + let tag_b = StringArray::from(vec![Some("ten"), Some("ten"), None]); + let float_field = Float64Array::from(vec![10.0, 10.1, 10.1]); + let int_field = Int64Array::from(vec![1, 2, 3]); + let time = TimestampNanosecondArray::from(vec![1000, 2000, 3000]); + + let batch = RecordBatch::try_from_iter_with_nullable(vec![ + ("tag_a", Arc::new(tag_a) as ArrayRef, true), + ("tag_b", Arc::new(tag_b), true), + ("float_field", Arc::new(float_field), true), + ("int_field", Arc::new(int_field), true), + ("time", Arc::new(time), false), + ]) + .unwrap(); + + // Input has one row that has no value (NULL value) for tag_b, which is its own series + let input = stream_from_batch(batch.schema(), batch); + + let table_name = "foo"; + let tag_columns = ["tag_a", "tag_b"]; + let field_columns = ["int_field"]; + let results = convert(table_name, &tag_columns, 
&field_columns, input).await; + + assert_eq!(results.len(), 2); + + assert_series_set( + &results[0], + "foo", + [("tag_a", "one"), ("tag_b", "ten")], + FieldIndexes::from_timestamp_and_value_indexes(4, &[3]), + [ + "+-------+-------+-------------+-----------+-----------------------------+", + "| tag_a | tag_b | float_field | int_field | time |", + "+-------+-------+-------------+-----------+-----------------------------+", + "| one | ten | 10.0 | 1 | 1970-01-01T00:00:00.000001Z |", + "| one | ten | 10.1 | 2 | 1970-01-01T00:00:00.000002Z |", + "+-------+-------+-------------+-----------+-----------------------------+", + ], + ); + assert_series_set( + &results[1], + "foo", + [("tag_a", "one")], // note no value for tag_b, only one tag + FieldIndexes::from_timestamp_and_value_indexes(4, &[3]), + [ + "+-------+-------+-------------+-----------+-----------------------------+", + "| tag_a | tag_b | float_field | int_field | time |", + "+-------+-------+-------------+-----------+-----------------------------+", + "| one | | 10.1 | 3 | 1970-01-01T00:00:00.000003Z |", + "+-------+-------+-------------+-----------+-----------------------------+", + ], + ); + } + + /// Test helper: run conversion and return a Vec + pub async fn convert<'a>( + table_name: &'a str, + tag_columns: &'a [&'a str], + field_columns: &'a [&'a str], + it: SendableRecordBatchStream, + ) -> Vec { + let mut converter = SeriesSetConverter::default(); + + let table_name = Arc::from(table_name); + let tag_columns = Arc::new(str_vec_to_arc_vec(tag_columns)); + let field_columns = FieldColumns::from(field_columns); + + converter + .convert(table_name, tag_columns, field_columns, it) + .await + .expect("Conversion happened without error") + .try_collect() + .await + .expect("Conversion happened without error") + } + + /// Test helper: parses the csv content into a single record batch arrow + /// arrays columnar ArrayRef according to the schema + fn parse_to_record_batch(schema: SchemaRef, data: &str) -> 
RecordBatch { + if data.is_empty() { + return RecordBatch::new_empty(schema); + } + + let batch_size = 1000; + let mut reader = csv::ReaderBuilder::new(schema) + .with_batch_size(batch_size) + .build_buffered(data.as_bytes()) + .unwrap(); + + let first_batch = reader.next().expect("Reading first batch"); + assert!( + first_batch.is_ok(), + "Can not parse record batch from csv: {first_batch:?}" + ); + assert!( + reader.next().is_none(), + "Unexpected batch while parsing csv" + ); + + println!("batch: \n{first_batch:#?}"); + + first_batch.unwrap() + } + + /// Parses a set of CSV lines into several `RecordBatchStream`s of varying sizes + /// + /// For example, with three input lines: + /// line1 + /// line2 + /// line3 + /// + /// This will produce two output streams: + /// Stream1: (line1), (line2), (line3) + /// Stream2: (line1, line2), (line3) + fn parse_to_iterators(schema: SchemaRef, lines: &[&str]) -> Vec { + split_lines(lines) + .into_iter() + .map(|batches| { + let batches = batches + .into_iter() + .map(|chunk| parse_to_record_batch(Arc::clone(&schema), &chunk)) + .collect::>(); + + stream_from_batches(Arc::clone(&schema), batches) + }) + .collect() + } + + fn split_lines(lines: &[&str]) -> Vec> { + println!("** Input data:\n{lines:#?}\n\n"); + if lines.is_empty() { + return vec![vec![], vec![String::from("")]]; + } + + // potential split points for batches + // we keep each split point twice so we may also produce empty batches + let n_lines = lines.len(); + let mut split_points = (0..=n_lines).chain(0..=n_lines).collect::>(); + split_points.sort(); + + let mut split_point_sets = split_points + .into_iter() + .powerset() + .map(|mut split_points| { + split_points.sort(); + + // ensure that "begin" and "end" are always split points + if split_points.first() != Some(&0) { + split_points.insert(0, 0); + } + if split_points.last() != Some(&n_lines) { + split_points.push(n_lines); + } + + split_points + }) + .collect::>(); + split_point_sets.sort(); + + let 
variants = split_point_sets + .into_iter() + .unique() + .map(|split_points| { + let batches = split_points + .into_iter() + .tuple_windows() + .map(|(begin, end)| lines[begin..end].join("\n")) + .collect::>(); + + // stream from those batches + assert!(!batches.is_empty()); + batches + }) + .collect::>(); + + assert!(!variants.is_empty()); + variants + } + + #[test] + fn test_split_lines() { + assert_eq!(split_lines(&[]), vec![vec![], vec![String::from("")],],); + + assert_eq!( + split_lines(&["foo"]), + vec![ + vec![String::from(""), String::from("foo")], + vec![String::from(""), String::from("foo"), String::from("")], + vec![String::from("foo")], + vec![String::from("foo"), String::from("")], + ], + ); + + assert_eq!( + split_lines(&["foo", "bar"]), + vec![ + vec![ + String::from(""), + String::from("foo"), + String::from(""), + String::from("bar") + ], + vec![ + String::from(""), + String::from("foo"), + String::from(""), + String::from("bar"), + String::from("") + ], + vec![String::from(""), String::from("foo"), String::from("bar")], + vec![ + String::from(""), + String::from("foo"), + String::from("bar"), + String::from("") + ], + vec![String::from(""), String::from("foo\nbar")], + vec![String::from(""), String::from("foo\nbar"), String::from("")], + vec![String::from("foo"), String::from(""), String::from("bar")], + vec![ + String::from("foo"), + String::from(""), + String::from("bar"), + String::from("") + ], + vec![String::from("foo"), String::from("bar")], + vec![String::from("foo"), String::from("bar"), String::from("")], + vec![String::from("foo\nbar")], + vec![String::from("foo\nbar"), String::from("")], + ], + ); + + assert_eq!( + split_lines(&["foo", "bar", "xxx"]), + vec![ + vec![ + String::from(""), + String::from("foo"), + String::from(""), + String::from("bar"), + String::from(""), + String::from("xxx") + ], + vec![ + String::from(""), + String::from("foo"), + String::from(""), + String::from("bar"), + String::from(""), + String::from("xxx"), + 
String::from("") + ], + vec![ + String::from(""), + String::from("foo"), + String::from(""), + String::from("bar"), + String::from("xxx") + ], + vec![ + String::from(""), + String::from("foo"), + String::from(""), + String::from("bar"), + String::from("xxx"), + String::from("") + ], + vec![ + String::from(""), + String::from("foo"), + String::from(""), + String::from("bar\nxxx") + ], + vec![ + String::from(""), + String::from("foo"), + String::from(""), + String::from("bar\nxxx"), + String::from("") + ], + vec![ + String::from(""), + String::from("foo"), + String::from("bar"), + String::from(""), + String::from("xxx") + ], + vec![ + String::from(""), + String::from("foo"), + String::from("bar"), + String::from(""), + String::from("xxx"), + String::from("") + ], + vec![ + String::from(""), + String::from("foo"), + String::from("bar"), + String::from("xxx") + ], + vec![ + String::from(""), + String::from("foo"), + String::from("bar"), + String::from("xxx"), + String::from("") + ], + vec![ + String::from(""), + String::from("foo"), + String::from("bar\nxxx") + ], + vec![ + String::from(""), + String::from("foo"), + String::from("bar\nxxx"), + String::from("") + ], + vec![ + String::from(""), + String::from("foo\nbar"), + String::from(""), + String::from("xxx") + ], + vec![ + String::from(""), + String::from("foo\nbar"), + String::from(""), + String::from("xxx"), + String::from("") + ], + vec![ + String::from(""), + String::from("foo\nbar"), + String::from("xxx") + ], + vec![ + String::from(""), + String::from("foo\nbar"), + String::from("xxx"), + String::from("") + ], + vec![String::from(""), String::from("foo\nbar\nxxx")], + vec![ + String::from(""), + String::from("foo\nbar\nxxx"), + String::from("") + ], + vec![ + String::from("foo"), + String::from(""), + String::from("bar"), + String::from(""), + String::from("xxx") + ], + vec![ + String::from("foo"), + String::from(""), + String::from("bar"), + String::from(""), + String::from("xxx"), + String::from("") + ], + 
vec![ + String::from("foo"), + String::from(""), + String::from("bar"), + String::from("xxx") + ], + vec![ + String::from("foo"), + String::from(""), + String::from("bar"), + String::from("xxx"), + String::from("") + ], + vec![ + String::from("foo"), + String::from(""), + String::from("bar\nxxx") + ], + vec![ + String::from("foo"), + String::from(""), + String::from("bar\nxxx"), + String::from("") + ], + vec![ + String::from("foo"), + String::from("bar"), + String::from(""), + String::from("xxx") + ], + vec![ + String::from("foo"), + String::from("bar"), + String::from(""), + String::from("xxx"), + String::from("") + ], + vec![ + String::from("foo"), + String::from("bar"), + String::from("xxx") + ], + vec![ + String::from("foo"), + String::from("bar"), + String::from("xxx"), + String::from("") + ], + vec![String::from("foo"), String::from("bar\nxxx")], + vec![ + String::from("foo"), + String::from("bar\nxxx"), + String::from("") + ], + vec![ + String::from("foo\nbar"), + String::from(""), + String::from("xxx") + ], + vec![ + String::from("foo\nbar"), + String::from(""), + String::from("xxx"), + String::from("") + ], + vec![String::from("foo\nbar"), String::from("xxx")], + vec![ + String::from("foo\nbar"), + String::from("xxx"), + String::from("") + ], + vec![String::from("foo\nbar\nxxx")], + vec![String::from("foo\nbar\nxxx"), String::from("")] + ] + ); + } + + #[tokio::test] + async fn test_group_generator_mem_limit() { + let memory_pool = Arc::new(GreedyMemoryPool::new(1)) as _; + + let ggen = GroupGenerator::new(vec![Arc::from("g")], memory_pool); + let input = futures::stream::iter([Ok(Series { + tags: vec![Tag { + key: Arc::from("g"), + value: Arc::from("x"), + }], + data: Data::FloatPoints(vec![Batch { + timestamps: vec![], + values: vec![], + }]), + })]); + let err = match ggen.group(input).await { + Ok(stream) => stream.try_collect::>().await.unwrap_err(), + Err(e) => e, + }; + assert_matches!(err, DataFusionError::ResourcesExhausted(_)); + } + + 
#[tokio::test] + async fn test_group_generator_no_mem_limit() { + let memory_pool = Arc::new(GreedyMemoryPool::new(usize::MAX)) as _; + // use a generator w/ a low buffered allocation to force multiple `alloc` calls + let ggen = GroupGenerator::new_with_buffered_size_max(vec![Arc::from("g")], memory_pool, 1); + let input = futures::stream::iter([ + Ok(Series { + tags: vec![Tag { + key: Arc::from("g"), + value: Arc::from("x"), + }], + data: Data::IntegerPoints(vec![Batch { + timestamps: vec![1], + values: vec![1], + }]), + }), + Ok(Series { + tags: vec![Tag { + key: Arc::from("g"), + value: Arc::from("y"), + }], + data: Data::IntegerPoints(vec![Batch { + timestamps: vec![2], + values: vec![2], + }]), + }), + Ok(Series { + tags: vec![Tag { + key: Arc::from("g"), + value: Arc::from("x"), + }], + data: Data::IntegerPoints(vec![Batch { + timestamps: vec![3], + values: vec![3], + }]), + }), + Ok(Series { + tags: vec![Tag { + key: Arc::from("g"), + value: Arc::from("x"), + }], + data: Data::IntegerPoints(vec![Batch { + timestamps: vec![4], + values: vec![4], + }]), + }), + ]); + let actual = ggen + .group(input) + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + let expected = vec![ + Either::Group(Group { + tag_keys: vec![Arc::from("g")], + partition_key_vals: vec![Arc::from("x")], + }), + Either::Series(Series { + tags: vec![Tag { + key: Arc::from("g"), + value: Arc::from("x"), + }], + data: Data::IntegerPoints(vec![Batch { + timestamps: vec![1], + values: vec![1], + }]), + }), + Either::Series(Series { + tags: vec![Tag { + key: Arc::from("g"), + value: Arc::from("x"), + }], + data: Data::IntegerPoints(vec![Batch { + timestamps: vec![3], + values: vec![3], + }]), + }), + Either::Series(Series { + tags: vec![Tag { + key: Arc::from("g"), + value: Arc::from("x"), + }], + data: Data::IntegerPoints(vec![Batch { + timestamps: vec![4], + values: vec![4], + }]), + }), + Either::Group(Group { + tag_keys: vec![Arc::from("g")], + partition_key_vals: 
vec![Arc::from("y")], + }), + Either::Series(Series { + tags: vec![Tag { + key: Arc::from("g"), + value: Arc::from("y"), + }], + data: Data::IntegerPoints(vec![Batch { + timestamps: vec![2], + values: vec![2], + }]), + }), + ]; + assert_eq!(actual, expected); + } + + fn assert_series_set( + set: &SeriesSet, + table_name: &'static str, + tags: [(&'static str, &'static str); N], + field_indexes: FieldIndexes, + data: [&'static str; M], + ) { + assert_eq!(set.table_name.as_ref(), table_name); + + let set_tags = set + .tags + .iter() + .map(|(a, b)| (a.as_ref(), b.as_ref())) + .collect::>(); + assert_eq!(set_tags.as_slice(), tags); + + assert_eq!(set.field_indexes, field_indexes); + + assert_batches_eq!(data, &[set.batch.slice(set.start_row, set.num_rows)]); + } +} diff --git a/iox_query/src/exec/seriesset/series.rs b/iox_query/src/exec/seriesset/series.rs new file mode 100644 index 0000000..4f12c5d --- /dev/null +++ b/iox_query/src/exec/seriesset/series.rs @@ -0,0 +1,775 @@ +//! This module contains the native Rust version of the Data frames +//! that are sent back in the storage gRPC format. 
+ +use std::{fmt, sync::Arc}; + +use arrow::{ + array::{ + Array, ArrayRef, BooleanArray, Float64Array, Int64Array, StringArray, + TimestampNanosecondArray, UInt64Array, + }, + compute, + datatypes::DataType as ArrowDataType, +}; +use predicate::rpc_predicate::{FIELD_COLUMN_NAME, MEASUREMENT_COLUMN_NAME}; + +use crate::exec::{field::FieldIndex, seriesset::SeriesSet}; +use snafu::Snafu; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("Unsupported data type while translating to Frames: {}", data_type))] + UnsupportedDataType { data_type: ArrowDataType }, + + #[snafu(display("Unsupported field data while translating to Frames: {}", data_type))] + UnsupportedFieldType { data_type: ArrowDataType }, +} + +pub type Result = std::result::Result; + +/// A name=value pair used to represent a series's tag +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Tag { + pub key: Arc, + pub value: Arc, +} + +impl Tag { + /// Memory usage in bytes, including `self`. + pub fn size(&self) -> usize { + std::mem::size_of_val(self) + self.key.len() + self.value.len() + } +} + +impl fmt::Display for Tag { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}={}", self.key, self.value) + } +} + +/// Represents a single logical TimeSeries +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Series { + /// key = value pairs that define this series + /// (including the _measurement and _field that correspond to table name and column name) + pub tags: Vec, + + /// The raw data for this series + pub data: Data, +} + +impl Series { + pub fn num_batches(&self) -> usize { + match &self.data { + Data::FloatPoints(batches) => batches.len(), + Data::IntegerPoints(batches) => batches.len(), + Data::UnsignedPoints(batches) => batches.len(), + Data::BooleanPoints(batches) => batches.len(), + Data::StringPoints(batches) => batches.len(), + } + } + + /// Memory usage in bytes, including `self`. 
+ pub fn size(&self) -> usize { + std::mem::size_of_val(self) + + (std::mem::size_of::() * self.tags.capacity()) + + self + .tags + .iter() + .map(|tag| tag.size() - std::mem::size_of_val(tag)) + .sum::() + + self.data.size() + - std::mem::size_of_val(&self.data) + } +} + +impl fmt::Display for Series { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Series tags={{")?; + let mut first = true; + self.tags.iter().try_for_each(|tag| { + if !first { + write!(f, ", ")?; + } else { + first = false; + } + write!(f, "{tag}") + })?; + writeln!(f, "}}")?; + write!(f, " {}", self.data) + } +} + +/// Typed data for a particular timeseries +#[derive(Clone, Debug)] +pub enum Data { + FloatPoints(Vec>), + IntegerPoints(Vec>), + UnsignedPoints(Vec>), + BooleanPoints(Vec>), + StringPoints(Vec>), +} + +impl Data { + /// Memory usage in bytes, including `self`. + pub fn size(&self) -> usize { + let data_sz: usize = match self { + Self::FloatPoints(points_vec) => points_vec.iter().map(|ps| ps.size()).sum(), + Self::IntegerPoints(points_vec) => points_vec.iter().map(|ps| ps.size()).sum(), + Self::UnsignedPoints(points_vec) => points_vec.iter().map(|ps| ps.size()).sum(), + Self::BooleanPoints(points_vec) => points_vec.iter().map(|ps| ps.size()).sum(), + Self::StringPoints(points_vec) => points_vec.iter().map(|ps| ps.size()).sum(), + }; + std::mem::size_of_val(self) + data_sz + } +} + +impl PartialEq for Data { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::FloatPoints(l_batches), Self::FloatPoints(r_batches)) => l_batches == r_batches, + (Self::IntegerPoints(l_batches), Self::IntegerPoints(r_batches)) => { + l_batches == r_batches + } + (Self::UnsignedPoints(l_batches), Self::UnsignedPoints(r_batches)) => { + l_batches == r_batches + } + (Self::BooleanPoints(l_batches), Self::BooleanPoints(r_batches)) => { + l_batches == r_batches + } + (Self::StringPoints(l_batches), Self::StringPoints(r_batches)) => { + l_batches == r_batches + } + _ 
=> false, + } + } +} + +impl Eq for Data {} + +/// Returns size of given vector of primitive types in bytes, EXCLUDING `vec` itself. +fn primitive_vec_size(vec: &Vec) -> usize { + std::mem::size_of::() * vec.capacity() +} + +impl fmt::Display for Data { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::FloatPoints(batches) => write!(f, "FloatPoints batches: {batches:?}"), + Self::IntegerPoints(batches) => write!(f, "IntegerPoints batches: {batches:?}"), + Self::UnsignedPoints(batches) => write!(f, "UnsignedPoints batches: {batches:?}"), + Self::BooleanPoints(batches) => write!(f, "BooleanPoints batches: {batches:?}"), + Self::StringPoints(batches) => write!(f, "StringPoints batches: {batches:?}"), + } + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Batch { + pub timestamps: Vec, + pub values: Vec, +} + +impl Batch { + fn size(&self) -> usize { + std::mem::size_of_val(self) + + primitive_vec_size(&self.timestamps) + + primitive_vec_size(&self.values) + } +} + +impl SeriesSet { + /// Returns true if the array is entirely null between start_row and + /// start_row+num_rows + fn is_all_null(arr: &ArrayRef) -> bool { + arr.null_count() == arr.len() + } + + pub fn is_timestamp_all_null(&self) -> bool { + self.field_indexes.iter().all(|field_index| { + let array = self.batch.column(field_index.timestamp_index); + Self::is_all_null(array) + }) + } + + pub fn try_into_series(self, batch_size: usize) -> Result> { + self.field_indexes + .iter() + .filter_map(|index| self.field_to_series(index, batch_size).transpose()) + .collect() + } + + // Convert and append the values from a single field to a Series + // appended to `frames` + fn field_to_series(&self, index: &FieldIndex, batch_size: usize) -> Result> { + let batch = self.batch.slice(self.start_row, self.num_rows); + let schema = batch.schema(); + + let field = schema.field(index.value_index); + let array = batch.column(index.value_index); + + // No values for this field are in 
the array so it does not + // contribute to a series. + if field.is_nullable() && Self::is_all_null(array) { + return Ok(None); + } + + let tags = self.create_frame_tags(schema.field(index.value_index).name()); + + let mut timestamps = compute::kernels::nullif::nullif( + batch.column(index.timestamp_index), + &compute::is_null(array).expect("is_null"), + ) + .expect("null handling") + .as_any() + .downcast_ref::() + .unwrap() + .extract_batched_values(batch_size); + timestamps.shrink_to_fit(); + + let data = match array.data_type() { + ArrowDataType::Utf8 => { + let values = array + .as_any() + .downcast_ref::() + .unwrap() + .extract_batched_values(batch_size); + Data::StringPoints(build_batches(timestamps, values)) + } + ArrowDataType::Float64 => { + let values = array + .as_any() + .downcast_ref::() + .unwrap() + .extract_batched_values(batch_size); + Data::FloatPoints(build_batches(timestamps, values)) + } + ArrowDataType::Int64 => { + let values = array + .as_any() + .downcast_ref::() + .unwrap() + .extract_batched_values(batch_size); + Data::IntegerPoints(build_batches(timestamps, values)) + } + ArrowDataType::UInt64 => { + let values = array + .as_any() + .downcast_ref::() + .unwrap() + .extract_batched_values(batch_size); + Data::UnsignedPoints(build_batches(timestamps, values)) + } + ArrowDataType::Boolean => { + let values = array + .as_any() + .downcast_ref::() + .unwrap() + .extract_batched_values(batch_size); + Data::BooleanPoints(build_batches(timestamps, values)) + } + _ => { + return UnsupportedDataTypeSnafu { + data_type: array.data_type().clone(), + } + .fail(); + } + }; + + Ok(Some(Series { tags, data })) + } + + /// Create the tag=value pairs for this series set, adding + /// adding the _f and _m tags for the field name and measurement + fn create_frame_tags(&self, field_name: &str) -> Vec { + // Add special _field and _measurement tags and return them in + // lexicographical (sorted) order + + let mut all_tags = self + .tags + .iter() + 
.cloned() + .chain([ + (Arc::from(FIELD_COLUMN_NAME), Arc::from(field_name)), + ( + Arc::from(MEASUREMENT_COLUMN_NAME), + Arc::clone(&self.table_name), + ), + ]) + .collect::>(); + + // sort by name + all_tags.sort_by(|(key1, _value), (key2, _value2)| key1.cmp(key2)); + + all_tags + .into_iter() + .map(|(key, value)| Tag { key, value }) + .collect() + } +} + +/// Zip together nested vectors of timestamps and values to create batches of points +fn build_batches(timestamps: Vec>, values: Vec>) -> Vec> { + timestamps + .into_iter() + .zip(values) + .map(|(timestamps, values)| Batch { timestamps, values }) + .collect() +} + +/// Represents a group of `Series` +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct Group { + /// Contains *ALL* tag keys (not just those used for grouping) + pub tag_keys: Vec>, + + /// Contains the values that define the group (may be values from + /// fields other than tags). + /// + /// the values of the group tags that defined the group. + /// For example, + /// + /// If there were tags `t0`, `t1`, and `t2`, and the query had + /// group_keys of `[t1, t2]` then this list would have the values + /// of the t1 and t2 columns + pub partition_key_vals: Vec>, +} + +impl fmt::Display for Group { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Group tag_keys: ")?; + fmt_strings(f, &self.tag_keys)?; + write!(f, " partition_key_vals: ")?; + fmt_strings(f, &self.partition_key_vals)?; + Ok(()) + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Either { + Series(Series), + Group(Group), +} + +impl From for Either { + fn from(value: Series) -> Self { + Self::Series(value) + } +} + +impl From for Either { + fn from(value: Group) -> Self { + Self::Group(value) + } +} + +impl fmt::Display for Either { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Series(series) => series.fmt(f), + Self::Group(group) => group.fmt(f), + } + } +} + +fn fmt_strings(f: &mut fmt::Formatter<'_>, 
strings: &[Arc]) -> fmt::Result { + let mut first = true; + strings.iter().try_for_each(|item| { + if !first { + write!(f, ", ")?; + } else { + first = false; + } + write!(f, "{item}") + }) +} + +trait ExtractBatchedValues { + /// Extracts rows as a vector, + /// for all rows `i` where `valid[i]` is set + fn extract_batched_values(&self, batch_size: usize) -> Vec>; +} + +/// Implements extract_batched_values for Arrow arrays. +macro_rules! extract_batched_values_impl { + ($DATA_TYPE:ty) => { + extract_batched_values_impl! { $DATA_TYPE, identity } + }; + ($DATA_TYPE:ty, $ITER_ADAPTER:expr) => { + fn extract_batched_values(&self, batch_size: usize) -> Vec> { + let num_batches = 1 + self.len() / batch_size; + let mut batches = Vec::with_capacity(num_batches); + + let mut v = Vec::with_capacity(batch_size); + for e in $ITER_ADAPTER(self.iter().flatten()) { + if v.len() >= batch_size { + batches.push(v); + v = Vec::with_capacity(batch_size); + } + v.push(e); + } + if !v.is_empty() { + v.shrink_to_fit(); + batches.push(v); + } + batches.shrink_to_fit(); + batches + } + }; +} + +fn identity(t: T) -> T { + t +} + +fn to_owned_string<'a, I>(i: I) -> impl Iterator +where + I: Iterator, +{ + i.map(str::to_string) +} + +impl ExtractBatchedValues for StringArray { + extract_batched_values_impl! { String, to_owned_string } +} + +impl ExtractBatchedValues for Int64Array { + extract_batched_values_impl! {i64} +} + +impl ExtractBatchedValues for UInt64Array { + extract_batched_values_impl! {u64} +} + +impl ExtractBatchedValues for Float64Array { + extract_batched_values_impl! {f64} +} + +impl ExtractBatchedValues for BooleanArray { + extract_batched_values_impl! {bool} +} + +impl ExtractBatchedValues for TimestampNanosecondArray { + extract_batched_values_impl! 
{i64} +} + +#[cfg(test)] +mod tests { + use crate::exec::field::FieldIndexes; + use arrow::{compute::concat_batches, record_batch::RecordBatch}; + + use super::*; + + fn series_set_to_series_strings(series_set: SeriesSet, batch_size: usize) -> Vec { + let series: Vec = series_set.try_into_series(batch_size).unwrap(); + + let series: Vec = series.into_iter().map(|s| s.to_string()).collect(); + + series + .iter() + .flat_map(|s| s.split('\n')) + .map(|s| s.to_string()) + .collect() + } + + #[test] + fn test_series_set_conversion() { + let series_set = SeriesSet { + table_name: Arc::from("the_table"), + tags: vec![(Arc::from("tag1"), Arc::from("val1"))], + field_indexes: FieldIndexes::from_timestamp_and_value_indexes(5, &[0, 1, 2, 3, 4]), + start_row: 1, + num_rows: 4, + batch: make_record_batch(), + }; + + let series_strings = series_set_to_series_strings(series_set, 3); + + let expected = vec![ + "Series tags={_field=string_field, _measurement=the_table, tag1=val1}", + " StringPoints batches: [Batch { timestamps: [2000, 3000, 4000], values: [\"bar\", \"baz\", \"bar\"] }, Batch { timestamps: [5000], values: [\"baz\"] }]", + "Series tags={_field=int_field, _measurement=the_table, tag1=val1}", + " IntegerPoints batches: [Batch { timestamps: [2000, 3000, 4000], values: [2, 3, 4] }, Batch { timestamps: [5000], values: [5] }]", + "Series tags={_field=uint_field, _measurement=the_table, tag1=val1}", + " UnsignedPoints batches: [Batch { timestamps: [2000, 3000, 4000], values: [22, 33, 44] }, Batch { timestamps: [5000], values: [55] }]", + "Series tags={_field=float_field, _measurement=the_table, tag1=val1}", + " FloatPoints batches: [Batch { timestamps: [2000, 3000, 4000], values: [20.1, 30.1, 40.1] }, Batch { timestamps: [5000], values: [50.1] }]", + "Series tags={_field=boolean_field, _measurement=the_table, tag1=val1}", + " BooleanPoints batches: [Batch { timestamps: [2000, 3000, 4000], values: [false, true, false] }, Batch { timestamps: [5000], values: [true] }]", + ]; 
+ + assert_eq!( + series_strings, expected, + "Expected:\n{expected:#?}\nActual:\n{series_strings:#?}" + ); + } + + #[test] + fn test_series_set_conversion_mixed_case_tags() { + let time1_array: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![1, 2, 3])); + let string1_array: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar", "baz"])); + + let batch = RecordBatch::try_from_iter(vec![ + ("time1", time1_array as ArrayRef), + ("string_field1", string1_array), + ]) + .expect("created new record batch"); + + let series_set = SeriesSet { + table_name: Arc::from("the_table"), + tags: vec![ + (Arc::from("CAPITAL_TAG"), Arc::from("the_value")), + (Arc::from("tag1"), Arc::from("val1")), + ], + // field indexes are (value, time) + field_indexes: FieldIndexes::from_slice(&[(1, 0)]), + start_row: 1, + num_rows: 2, + batch, + }; + + let series_strings = series_set_to_series_strings(series_set, 100); + + // expect CAPITAL_TAG is before `_field` and `_measurement` tags + // (as that is the correct lexicographical ordering) + let expected = vec![ + "Series tags={CAPITAL_TAG=the_value, _field=string_field1, _measurement=the_table, tag1=val1}", + " StringPoints batches: [Batch { timestamps: [2, 3], values: [\"bar\", \"baz\"] }]", + ]; + + assert_eq!( + series_strings, expected, + "Expected:\n{expected:#?}\nActual:\n{series_strings:#?}" + ); + } + + #[test] + fn test_series_set_conversion_different_time_columns() { + let time1_array: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![1, 2, 3])); + let string1_array: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar", "baz"])); + let time2_array: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![3, 4, 5])); + let string2_array: ArrayRef = Arc::new(StringArray::from(vec!["boo", "far", "faz"])); + + let batch = RecordBatch::try_from_iter(vec![ + ("time1", time1_array as ArrayRef), + ("string_field1", string1_array), + ("time2", time2_array), + ("string_field2", string2_array), + ]) + .expect("created new 
record batch"); + + let series_set = SeriesSet { + table_name: Arc::from("the_table"), + tags: vec![(Arc::from("tag1"), Arc::from("val1"))], + // field indexes are (value, time) + field_indexes: FieldIndexes::from_slice(&[(3, 2), (1, 0)]), + start_row: 1, + num_rows: 2, + batch, + }; + + let series_strings = series_set_to_series_strings(series_set, 100); + + let expected = vec![ + "Series tags={_field=string_field2, _measurement=the_table, tag1=val1}", + " StringPoints batches: [Batch { timestamps: [4, 5], values: [\"far\", \"faz\"] }]", + "Series tags={_field=string_field1, _measurement=the_table, tag1=val1}", + " StringPoints batches: [Batch { timestamps: [2, 3], values: [\"bar\", \"baz\"] }]", + ]; + + assert_eq!( + series_strings, expected, + "Expected:\n{expected:#?}\nActual:\n{series_strings:#?}" + ); + } + + #[test] + fn test_series_set_conversion_with_entirely_null_field() { + // single series + let tag_array: ArrayRef = Arc::new(StringArray::from(vec!["MA", "MA", "MA", "MA"])); + let int_array: ArrayRef = Arc::new(Int64Array::from(vec![None, None, None, None])); + let float_array: ArrayRef = Arc::new(Float64Array::from(vec![ + Some(10.1), + Some(20.1), + None, + Some(40.1), + ])); + + let timestamp_array: ArrayRef = + Arc::new(TimestampNanosecondArray::from(vec![1000, 2000, 3000, 4000])); + + let batch = RecordBatch::try_from_iter_with_nullable(vec![ + ("state", tag_array, true), + ("int_field", int_array, true), + ("float_field", float_array, true), + ("time", timestamp_array, false), + ]) + .expect("created new record batch"); + + let series_set = SeriesSet { + table_name: Arc::from("the_table"), + tags: vec![(Arc::from("state"), Arc::from("MA"))], + field_indexes: FieldIndexes::from_timestamp_and_value_indexes(3, &[1, 2]), + start_row: 0, + num_rows: batch.num_rows(), + batch: batch.clone(), + }; + + // Expect only a single series (for the data in float_field, int_field is all + // nulls) + let series_strings = series_set_to_series_strings(series_set, 
100); + + let expected = vec![ + "Series tags={_field=float_field, _measurement=the_table, state=MA}", + " FloatPoints batches: [Batch { timestamps: [1000, 2000, 4000], values: [10.1, 20.1, 40.1] }]", + ]; + + assert_eq!( + series_strings, expected, + "Expected:\n{expected:#?}\nActual:\n{series_strings:#?}" + ); + + // Multi-batch case + // We can just append record batches here because the tag field does not change + let batch = repeat_batch(3, &batch); + let series_set = SeriesSet { + table_name: Arc::from("the_table"), + tags: vec![(Arc::from("state"), Arc::from("MA"))], + field_indexes: FieldIndexes::from_timestamp_and_value_indexes(3, &[1, 2]), + start_row: 0, + num_rows: batch.num_rows(), + batch, + }; + + let series_strings = series_set_to_series_strings(series_set, 4); + let expected = vec![ + "Series tags={_field=float_field, _measurement=the_table, state=MA}", + " FloatPoints batches: [Batch { timestamps: [1000, 2000, 4000, 1000], values: [10.1, 20.1, 40.1, 10.1] }, Batch { timestamps: [2000, 4000, 1000, 2000], values: [20.1, 40.1, 10.1, 20.1] }, Batch { timestamps: [4000], values: [40.1] }]", + ]; + + assert_eq!( + series_strings, expected, + "Expected:\n{expected:#?}\nActual:\n{series_strings:#?}" + ); + } + + #[test] + fn test_series_set_conversion_with_some_null_fields() { + // single series + let tag_array = StringArray::from(vec!["MA", "MA"]); + let string_array = StringArray::from(vec![None, Some("foo")]); + let float_array = Float64Array::from(vec![None, Some(1.0)]); + let int_array = Int64Array::from(vec![None, Some(-10)]); + let uint_array = UInt64Array::from(vec![None, Some(100)]); + let bool_array = BooleanArray::from(vec![None, Some(true)]); + + let timestamp_array = TimestampNanosecondArray::from(vec![1000, 2000]); + + let batch = RecordBatch::try_from_iter_with_nullable(vec![ + ("state", Arc::new(tag_array) as ArrayRef, true), + ("string_field", Arc::new(string_array), true), + ("float_field", Arc::new(float_array), true), + ("int_field", 
Arc::new(int_array), true), + ("uint_field", Arc::new(uint_array), true), + ("bool_field", Arc::new(bool_array), true), + ("time", Arc::new(timestamp_array), false), + ]) + .expect("created new record batch"); + + let series_set = SeriesSet { + table_name: Arc::from("the_table"), + tags: vec![(Arc::from("state"), Arc::from("MA"))], + field_indexes: FieldIndexes::from_timestamp_and_value_indexes(6, &[1, 2, 3, 4, 5]), + start_row: 0, + num_rows: batch.num_rows(), + batch: batch.clone(), + }; + + // Expect only a single series (for the data in float_field, int_field is all + // nulls) + let series_strings = series_set_to_series_strings(series_set, 100); + + let expected = vec![ + "Series tags={_field=string_field, _measurement=the_table, state=MA}", + " StringPoints batches: [Batch { timestamps: [2000], values: [\"foo\"] }]", + "Series tags={_field=float_field, _measurement=the_table, state=MA}", + " FloatPoints batches: [Batch { timestamps: [2000], values: [1.0] }]", + "Series tags={_field=int_field, _measurement=the_table, state=MA}", + " IntegerPoints batches: [Batch { timestamps: [2000], values: [-10] }]", + "Series tags={_field=uint_field, _measurement=the_table, state=MA}", + " UnsignedPoints batches: [Batch { timestamps: [2000], values: [100] }]", + "Series tags={_field=bool_field, _measurement=the_table, state=MA}", + " BooleanPoints batches: [Batch { timestamps: [2000], values: [true] }]", + ]; + + assert_eq!( + series_strings, expected, + "Expected:\n{expected:#?}\nActual:\n{series_strings:#?}" + ); + + // multi-batch case + + // the tag columns have just a single value so we can just repeat the original batch to + // generate more rows + let batch = repeat_batch(4, &batch); + let series_set = SeriesSet { + table_name: Arc::from("the_table"), + tags: vec![(Arc::from("state"), Arc::from("MA"))], + field_indexes: FieldIndexes::from_timestamp_and_value_indexes(6, &[1, 2, 3, 4, 5]), + start_row: 0, + num_rows: batch.num_rows(), + batch, + }; + + let 
series_strings = series_set_to_series_strings(series_set, 3); + + let expected = vec![ + "Series tags={_field=string_field, _measurement=the_table, state=MA}", + " StringPoints batches: [Batch { timestamps: [2000, 2000, 2000], values: [\"foo\", \"foo\", \"foo\"] }, Batch { timestamps: [2000], values: [\"foo\"] }]", + "Series tags={_field=float_field, _measurement=the_table, state=MA}", + " FloatPoints batches: [Batch { timestamps: [2000, 2000, 2000], values: [1.0, 1.0, 1.0] }, Batch { timestamps: [2000], values: [1.0] }]", + "Series tags={_field=int_field, _measurement=the_table, state=MA}", + " IntegerPoints batches: [Batch { timestamps: [2000, 2000, 2000], values: [-10, -10, -10] }, Batch { timestamps: [2000], values: [-10] }]", + "Series tags={_field=uint_field, _measurement=the_table, state=MA}", + " UnsignedPoints batches: [Batch { timestamps: [2000, 2000, 2000], values: [100, 100, 100] }, Batch { timestamps: [2000], values: [100] }]", + "Series tags={_field=bool_field, _measurement=the_table, state=MA}", + " BooleanPoints batches: [Batch { timestamps: [2000, 2000, 2000], values: [true, true, true] }, Batch { timestamps: [2000], values: [true] }]", + ]; + + assert_eq!( + series_strings, expected, + "Expected:\n{expected:#?}\nActual:\n{series_strings:#?}" + ); + } + + fn make_record_batch() -> RecordBatch { + let string_array: ArrayRef = Arc::new(StringArray::from(vec![ + "foo", "bar", "baz", "bar", "baz", "foo", + ])); + let int_array: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5, 6])); + let uint_array: ArrayRef = Arc::new(UInt64Array::from(vec![11, 22, 33, 44, 55, 66])); + let float_array: ArrayRef = + Arc::new(Float64Array::from(vec![10.1, 20.1, 30.1, 40.1, 50.1, 60.1])); + let bool_array: ArrayRef = Arc::new(BooleanArray::from(vec![ + true, false, true, false, true, false, + ])); + + let timestamp_array: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![ + 1000, 2000, 3000, 4000, 5000, 6000, + ])); + + 
RecordBatch::try_from_iter_with_nullable(vec![ + ("string_field", string_array, true), + ("int_field", int_array, true), + ("uint_field", uint_array, true), + ("float_field", float_array, true), + ("boolean_field", bool_array, true), + ("time", timestamp_array, true), + ]) + .expect("created new record batch") + } + + fn repeat_batch(count: usize, rb: &RecordBatch) -> RecordBatch { + concat_batches(&rb.schema(), std::iter::repeat(rb).take(count)).unwrap() + } +} diff --git a/iox_query/src/exec/sleep.rs b/iox_query/src/exec/sleep.rs new file mode 100644 index 0000000..b7fa505 --- /dev/null +++ b/iox_query/src/exec/sleep.rs @@ -0,0 +1,265 @@ +/// Implementation of a "sleep" operation in DataFusion. +/// +/// The sleep operation passes through its input data and sleeps asynchronously for a duration determined by an +/// expression. The async sleep is implemented as a special [execution plan](SleepExpr) so we can perform this as part +/// of the async data stream. In contrast to a UDF, this will NOT block any threads. +use std::{sync::Arc, time::Duration}; + +use arrow::{ + array::{Array, Float32Array, Float64Array, Int64Array}, + datatypes::{DataType, SchemaRef, TimeUnit}, +}; +use datafusion::{ + common::DFSchemaRef, + error::DataFusionError, + execution::{context::SessionState, TaskContext}, + logical_expr::{LogicalPlan, UserDefinedLogicalNodeCore}, + physical_plan::{ + stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, + PhysicalExpr, SendableRecordBatchStream, Statistics, + }, + physical_planner::PhysicalPlanner, + prelude::Expr, +}; +use futures::TryStreamExt; + +/// Logical plan note that represents a "sleep" operation. +/// +/// This will be lowered to [`SleepExpr`]. +/// +/// See [module](super) docs for more details. 
+#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct SleepNode { + input: LogicalPlan, + duration: Vec, +} + +impl SleepNode { + pub fn new(input: LogicalPlan, duration: Vec) -> Self { + Self { input, duration } + } + + pub fn plan( + &self, + planner: &dyn PhysicalPlanner, + logical_inputs: &[&LogicalPlan], + physical_inputs: &[Arc], + session_state: &SessionState, + ) -> Result { + let duration = self + .duration + .iter() + .map(|e| { + planner.create_physical_expr( + e, + logical_inputs[0].schema(), + &physical_inputs[0].schema(), + session_state, + ) + }) + .collect::, _>>()?; + Ok(SleepExpr::new(Arc::clone(&physical_inputs[0]), duration)) + } +} + +impl UserDefinedLogicalNodeCore for SleepNode { + fn name(&self) -> &str { + "Sleep" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec { + self.duration.clone() + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let duration = self + .duration + .iter() + .map(|e| e.to_string()) + .collect::>() + .join(", "); + + write!(f, "{}: duration=[{}]", self.name(), duration) + } + + fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + Self::new(inputs[0].clone(), exprs.to_vec()) + } +} + +/// Physical node that implements a "sleep" operation. +/// +/// This was lowered from [`SleepNode`]. +/// +/// See [module](super) docs for more details. +#[derive(Debug)] +pub struct SleepExpr { + /// Input data. + input: Arc, + + /// Expression that determines the sum of the sleep duration. 
+ duration: Vec>, +} + +impl SleepExpr { + pub fn new(input: Arc, duration: Vec>) -> Self { + Self { input, duration } + } +} + +impl DisplayAs for SleepExpr { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let duration = self + .duration + .iter() + .map(|e| e.to_string()) + .collect::>() + .join(", "); + + write!(f, "Sleep: duration=[{}]", duration) + } + } + } +} + +impl ExecutionPlan for SleepExpr { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.input.schema() + } + + fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning { + self.input.output_partitioning() + } + + fn output_ordering(&self) -> Option<&[datafusion::physical_expr::PhysicalSortExpr]> { + self.input.output_ordering() + } + + fn children(&self) -> Vec> { + vec![Arc::clone(&self.input)] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion::error::Result> { + assert_eq!(children.len(), 1); + + Ok(Arc::new(Self::new( + Arc::clone(&children[0]), + self.duration.clone(), + ))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> datafusion::error::Result { + let stream = self.input.execute(partition, context)?; + + let duration = self.duration.clone(); + let stream = RecordBatchStreamAdapter::new( + stream.schema(), + stream.and_then(move |batch| { + let duration = duration.clone(); + + async move { + let mut sum = Duration::ZERO; + for expr in duration { + let array = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let d = array_to_duration(&array)?; + if let Some(d) = d { + sum += d; + } + } + if !sum.is_zero() { + tokio::time::sleep(sum).await; + } + Ok(batch) + } + }), + ); + Ok(Box::pin(stream)) + } + + fn statistics(&self) -> Result { + Ok(Statistics::new_unknown(&self.schema())) + } +} + +fn array_to_duration(array: &dyn Array) -> Result, 
DataFusionError> { + match array.data_type() { + DataType::Null => Ok(None), + DataType::Duration(tunit) => { + let array = arrow::compute::cast(array, &DataType::Int64)?; + let array = array + .as_any() + .downcast_ref::() + .expect("just casted"); + let Some(sum) = arrow::compute::sum(array) else { + return Ok(None); + }; + if sum < 0 { + return Err(DataFusionError::Execution(format!( + "duration must be non-negative but is {sum}{tunit:?}" + ))); + } + let sum = sum as u64; + let duration = match tunit { + TimeUnit::Second => Duration::from_secs(sum), + TimeUnit::Millisecond => Duration::from_millis(sum), + TimeUnit::Microsecond => Duration::from_micros(sum), + TimeUnit::Nanosecond => Duration::from_nanos(sum), + }; + Ok(Some(duration)) + } + DataType::Float32 => { + let array = array + .as_any() + .downcast_ref::() + .expect("just checked"); + let Some(sum) = arrow::compute::sum(array) else { + return Ok(None); + }; + if sum < 0.0 || !sum.is_finite() { + return Err(DataFusionError::Execution(format!( + "duration must be non-negative but is {sum}s" + ))); + } + Ok(Some(Duration::from_secs_f32(sum))) + } + DataType::Float64 => { + let array = array + .as_any() + .downcast_ref::() + .expect("just checked"); + let Some(sum) = arrow::compute::sum(array) else { + return Ok(None); + }; + if sum < 0.0 || !sum.is_finite() { + return Err(DataFusionError::Execution(format!( + "duration must be non-negative but is {sum}s" + ))); + } + Ok(Some(Duration::from_secs_f64(sum))) + } + other => Err(DataFusionError::Internal(format!( + "Expected duration pattern to sleep(...), got: {other:?}" + ))), + } +} diff --git a/iox_query/src/exec/split.rs b/iox_query/src/exec/split.rs new file mode 100644 index 0000000..3010884 --- /dev/null +++ b/iox_query/src/exec/split.rs @@ -0,0 +1,931 @@ +//! This module contains a DataFusion extension node to "split" a +//! stream based on an expression. +//! +//! All rows for which the expression are true are sent to partition +//! 
`0` and all other rows are sent to partition `1`. +//! +//! There are corresponding [`LogicalPlan`] ([`StreamSplitNode`]) and +//! [`ExecutionPlan`] ([`StreamSplitExec`]) implementations, which are +//! typically used as shown in the following diagram: +//! +//! +//! ```text +//! partition 0 partition 1 +//! ▲ ▲ +//! │ │ +//! └────────────┬──────────┘ +//! │ +//! │ +//! ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +//! StreamSplitExec │ +//! │ expr +//! ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ +//! ▲ +//! │ +//! ┌────────────────────────┐ +//! │ Union │ +//! │ │ +//! └────────────────────────┘ +//! ▲ +//! │ +//! +//! Other IOxScan code +//! ┌────────────────────────┐ (Filter, Dedup, etc) +//! │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │ ... +//! │ StreamSplit │ │ +//! │ └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │ ▲ +//! │ Extension │ │ +//! └────────────────────────┘ │ +//! ▲ ┌────────────────────────┐ +//! │ │ TableProvider │ +//! ┌────────────────────────┐ │ │ +//! │ TableScan │ └────────────────────────┘ +//! │ │ +//! └────────────────────────┘ +//! +//! Execution Plan +//! Logical Plan (Physical Plan) +//! 
``` + +use std::{ + fmt::{self, Debug}, + sync::Arc, +}; + +use arrow::{ + array::{as_boolean_array, Array, ArrayRef, BooleanArray}, + compute::{self, filter_record_batch}, + datatypes::SchemaRef, + record_batch::RecordBatch, +}; +use datafusion::{ + common::DFSchemaRef, + error::{DataFusionError, Result}, + execution::context::TaskContext, + logical_expr::{Expr, LogicalPlan, UserDefinedLogicalNodeCore}, + physical_expr::PhysicalSortRequirement, + physical_plan::{ + expressions::PhysicalSortExpr, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput}, + ColumnarValue, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + PhysicalExpr, SendableRecordBatchStream, Statistics, + }, + scalar::ScalarValue, +}; + +use datafusion_util::{watch::WatchedTask, AdapterStream}; +use futures::StreamExt; +use observability_deps::tracing::*; +use parking_lot::Mutex; +use tokio::sync::mpsc::Sender; + +/// Implements stream splitting described in `make_stream_split` +/// +/// The resulting execution plan always produces exactly split_exprs's length + 1 partitions: +/// +/// * partition i (i < split_exprs.len()) are the rows for which the `split_expr[i]` +/// evaluates to true. If the rows are evaluated true for both `split_expr[i]` and +/// `split_expr[j]`, where i < j, the rows will be sent to partition i. However, +/// this will be mostly used in the use case of range expressions (e.g: [2 <= x, 2= x <= 5]) +/// in which rows are only evaluated to true in at most one of the expressions. +/// * partition n (n = partition split_exprs.len()) are the rows for which all split_exprs +/// do not evaluate to true (e.g. Null or false) +#[derive(Hash, PartialEq, Eq)] +pub struct StreamSplitNode { + input: LogicalPlan, + split_exprs: Vec, +} + +impl StreamSplitNode { + /// Create a new `StreamSplitNode` using `split_exprs` to divide the + /// rows. All `split_exprs` must evaluate to a boolean otherwise a + /// runtime error will occur. 
+ pub fn new(input: LogicalPlan, split_exprs: Vec) -> Self { + Self { input, split_exprs } + } + + pub fn split_exprs(&self) -> &Vec { + &self.split_exprs + } +} + +impl Debug for StreamSplitNode { + /// Use explain format for the Debug format. + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.fmt_for_explain(f) + } +} + +impl UserDefinedLogicalNodeCore for StreamSplitNode { + fn name(&self) -> &str { + "StreamSplit" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + /// Schema is the same as the input schema + fn schema(&self) -> &DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec { + self.split_exprs.clone() + } + + fn fmt_for_explain(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{} split_expr={:?}", self.name(), self.split_exprs) + } + + fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + assert_eq!(inputs.len(), 1, "StreamSplitNode: input sizes inconsistent"); + Self { + input: inputs[0].clone(), + split_exprs: (*exprs).to_vec(), + } + } +} + +/// Tracks the state of the physical operator +enum State { + New, + Running { + streams: Vec>, + }, +} + +/// Physical operator that implements steam splitting operation +pub struct StreamSplitExec { + state: Mutex, + input: Arc, + split_exprs: Vec>, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, +} + +impl StreamSplitExec { + pub fn new(input: Arc, split_exprs: Vec>) -> Self { + let state = Mutex::new(State::New); + Self { + state, + input, + split_exprs, + metrics: ExecutionPlanMetricsSet::new(), + } + } +} + +impl Debug for StreamSplitExec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "StreamSplitExec {:?}", self.split_exprs) + } +} + +impl ExecutionPlan for StreamSplitExec { + fn as_any(&self) -> &(dyn std::any::Any + 'static) { + self + } + + fn schema(&self) -> SchemaRef { + self.input.schema() + } + + /// Always produces exactly two outputs + fn 
output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(self.split_exprs.len() + 1) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.input.output_ordering() + } + + /// Always require a single input (eventually we might imagine + /// running this on multiple partitions concurrently to compute + /// the splits in parallel, but not now) + fn required_input_distribution(&self) -> Vec { + vec![Distribution::SinglePartition] + } + + fn required_input_ordering(&self) -> Vec>> { + // require that the output ordering of the child is preserved + // (so that this node logically splits what was desired) + let requirement = self + .input + .output_ordering() + .map(PhysicalSortRequirement::from_sort_exprs); + + vec![requirement] + } + + fn children(&self) -> Vec> { + vec![Arc::clone(&self.input)] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + match children.len() { + 1 => Ok(Arc::new(Self::new( + Arc::clone(&children[0]), + self.split_exprs.clone(), + ))), + _ => Err(DataFusionError::Internal( + "StreamSplitExec wrong number of children".to_string(), + )), + } + } + + /// Stream split has multiple partitions from 0 to n + /// Each partition i includes rows for which `split_exprs[i]` evaluate to true + /// + /// # Deadlock + /// + /// This will deadlock unless all partitions are consumed from + /// concurrently. Failing to consume from one partition blocks the other + /// partitions from progressing. 
+ fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + trace!(partition, "Start SplitExec::execute"); + self.start_if_needed(context)?; + + let mut state = self.state.lock(); + match &mut (*state) { + State::New => panic!("should have been initialized"), + State::Running { streams } => { + assert!(partition < streams.len()); + let stream = streams[partition].take().unwrap_or_else(|| { + panic!("Error executing stream #{partition} of StreamSplitExec"); + }); + trace!(partition, "End SplitExec::execute"); + Ok(stream) + } + } + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + // For now, don't return any statistics (in the future we + // could potentially estimate the output cardinalities) + Ok(Statistics::new_unknown(&self.schema())) + } +} + +impl DisplayAs for StreamSplitExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "StreamSplitExec") + } + } + } +} + +impl StreamSplitExec { + /// if in State::New, sets up the output running and sets self.state --> `Running` + fn start_if_needed(&self, context: Arc) -> Result<()> { + let mut state = self.state.lock(); + if matches!(*state, State::Running { .. 
}) { + return Ok(()); + } + + let num_input_streams = self.input.output_partitioning().partition_count(); + assert_eq!( + num_input_streams, 1, + "need exactly one input partition for stream split exec" + ); + + trace!("Setting up SplitStreamExec state"); + let input_stream = self.input.execute(0, context)?; + + let split_exprs = self.split_exprs.clone(); + + let num_streams = split_exprs.len() + 1; + let mut baseline_metrics = Vec::with_capacity(num_streams); + let mut txs = Vec::with_capacity(num_streams); + let mut rxs = Vec::with_capacity(num_streams); + for i in 0..num_streams { + baseline_metrics.push(BaselineMetrics::new(&self.metrics, i)); + let (tx, rx) = tokio::sync::mpsc::channel(2); + txs.push(tx); + rxs.push(rx); + } + + // launch the work on a different task, with a task to handle its output values + let fut = split_the_stream(input_stream, split_exprs, txs.clone(), baseline_metrics); + let handle = WatchedTask::new(fut, txs, "split"); + + let streams = rxs + .into_iter() + .map(|rx| { + Some(AdapterStream::adapt( + self.input.schema(), + rx, + Arc::clone(&handle), + )) + }) + .collect::>(); + + *state = State::Running { streams }; + + Ok(()) + } +} + +/// This function does the actual splitting: evaluates `split_exprs` on +/// each input [`RecordBatch`], and then sends the rows to the correct +/// output `tx[i]` +async fn split_the_stream( + mut input_stream: SendableRecordBatchStream, + split_exprs: Vec>, + tx: Vec>>, + baseline_metrics: Vec, +) -> std::result::Result<(), DataFusionError> { + assert_eq!(split_exprs.len() + 1, tx.len()); + assert_eq!(tx.len(), baseline_metrics.len()); + + let elapsed_computes = baseline_metrics + .iter() + .map(|b| b.elapsed_compute()) + .collect::>(); + + while let Some(batch) = input_stream.next().await { + let batch = batch?; + trace!(num_rows = batch.num_rows(), "Processing batch"); + + // All streams are not done yet + let mut tx_done = tx.iter().map(|_| false).collect::>(); + + // Get data from the current 
batch for each stream + let mut remaining_indices: Option = None; + for i in 0..split_exprs.len() { + let timer = elapsed_computes[i].timer(); + let expr = &split_exprs[i]; + + // Compute indices that meets this expr + let true_indices = expr.evaluate(&batch)?; + // Indices that does not meet this expr + let not_true_indices = negate(&true_indices)?; + + // Indices that do not meet all exprs + if let Some(not_true) = remaining_indices { + remaining_indices = Some( + and(¬_true, ¬_true_indices) + .expect("Error computing combining negating indices"), + ); + } else { + remaining_indices = Some(not_true_indices); + }; + + // data that meets expr + let true_batch = compute_batch(&batch, true_indices, false)?; + timer.done(); + + // record output counts + let true_batch = true_batch.record_output(&baseline_metrics[i]); + + // don't treat a hangup as an error, as it can also be caused + // by a LIMIT operation where the entire stream is not + // consumed) + if let Err(e) = tx[i].send(Ok(true_batch)).await { + debug!(%e, "Split tx[{}] hung up, ignoring", i); + tx_done[i] = true; + } + } + + // last stream of data gets values that did not get routed to other streams + let timer = elapsed_computes[elapsed_computes.len() - 1].timer(); + let remaining_indices = + remaining_indices.expect("The last set of indices of the split should have values"); + let final_not_true_batch = compute_batch(&batch, remaining_indices, true)?; + timer.done(); + + // record output counts + let final_not_true_batch = + final_not_true_batch.record_output(&baseline_metrics[elapsed_computes.len() - 1]); + + // don't treat a hangup as an error, as it can also be caused + // by a LIMIT operation where the entire stream is not + // consumed) + if let Err(e) = tx[elapsed_computes.len() - 1] + .send(Ok(final_not_true_batch)) + .await + { + debug!(%e, "Split tx[{}] hung up, ignoring", elapsed_computes.len()-1); + tx_done[elapsed_computes.len() - 1] = true; + } + + if tx_done.iter().all(|x| *x) { + 
debug!("All split tx ends have hung up, stopping loop"); + return Ok(()); + } + } + + trace!("Splitting done successfully"); + Ok(()) +} + +fn compute_batch( + input_batch: &RecordBatch, + indices: ColumnarValue, + last_batch: bool, +) -> Result { + let batch = match indices { + ColumnarValue::Array(indices) => { + let indices = indices.as_any().downcast_ref::().unwrap(); + + // include null for last batch + if last_batch && indices.null_count() > 0 { + // since !Null --> Null, but we want all the + // remaining rows, that are not in true_indicies, + // transform any nulls into true for this one + let mapped_indicies = indices.iter().map(|v| v.or(Some(true))).collect::>(); + + filter_record_batch(input_batch, &BooleanArray::from(mapped_indicies)) + } else { + filter_record_batch(input_batch, indices) + }? + } + ColumnarValue::Scalar(ScalarValue::Boolean(val)) => { + let empty_record_batch = RecordBatch::new_empty(input_batch.schema()); + match val { + Some(true) => input_batch.clone(), + Some(false) => empty_record_batch, + _ => panic!("mismatched boolean values: {val:?}"), + } + } + _ => { + panic!("mismatched array types"); + } + }; + + Ok(batch) +} + +/// compute the boolean compliment of the columnar value (which must be boolean) +fn negate(v: &ColumnarValue) -> Result { + match v { + ColumnarValue::Array(arr) => { + let arr = arr.as_any().downcast_ref::().ok_or_else(|| { + let msg = format!("Expected boolean array, but had type {:?}", arr.data_type()); + DataFusionError::Internal(msg) + })?; + let neg_array = Arc::new(compute::not(arr)?) 
as ArrayRef; + Ok(ColumnarValue::Array(neg_array)) + } + ColumnarValue::Scalar(val) => { + if let ScalarValue::Boolean(v) = val { + let not_v = v.map(|v| !v); + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(not_v))) + } else { + let msg = format!( + "Expected boolean literal, but got type {:?}", + val.data_type() + ); + Err(DataFusionError::Internal(msg)) + } + } + } +} + +fn and(left: &ColumnarValue, right: &ColumnarValue) -> Result { + match (left, right) { + (ColumnarValue::Array(arr_left), ColumnarValue::Array(arr_right)) => { + let arr_left = as_boolean_array(arr_left); + let arr_right = as_boolean_array(arr_right); + let and_array = Arc::new(compute::and(arr_left, arr_right)?) as ArrayRef; + Ok(ColumnarValue::Array(and_array)) + } + (ColumnarValue::Scalar(val_left), ColumnarValue::Scalar(val_right)) => { + if let (ScalarValue::Boolean(Some(v_left)), ScalarValue::Boolean(Some(v_right))) = + (val_left, val_right) + { + let and_val = v_left & v_right; + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(and_val)))) + } else { + let msg = format!( + "Expected two boolean literals, but got type {:?} and type {:?}", + val_left.data_type(), + val_right.data_type() + ); + Err(DataFusionError::Internal(msg)) + } + } + _ => { + panic!("Expected either two boolean arrays or two boolean scalars, but had type {:?} and type {:?}", left.data_type(), right.data_type()); + } + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Int64Array, StringArray}; + use arrow_util::assert_batches_sorted_eq; + use datafusion::{ + physical_plan::memory::MemoryExec, + prelude::{col, lit}, + }; + use datafusion_util::test_collect_partition; + + use crate::util::df_physical_expr; + + use super::*; + + #[tokio::test] + async fn test_basic_split() { + test_helpers::maybe_start_logging(); + let batch0 = RecordBatch::try_from_iter(vec![ + ( + "int_col", + Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ( + "str_col", + Arc::new(StringArray::from(vec!["one", "two", "three"])) as 
ArrayRef, + ), + ]) + .unwrap(); + + let batch1 = RecordBatch::try_from_iter(vec![ + ( + "int_col", + Arc::new(Int64Array::from(vec![4, -2])) as ArrayRef, + ), + ( + "str_col", + Arc::new(StringArray::from(vec!["four", "negative 2"])) as ArrayRef, + ), + ]) + .unwrap(); + + let input = make_input(vec![vec![batch0, batch1]]); + // int_col < 3 + let split_expr = df_physical_expr(input.schema(), col("int_col").lt(lit(3))).unwrap(); + let split_exec: Arc = + Arc::new(StreamSplitExec::new(input, vec![split_expr])); + + let output0 = test_collect_partition(Arc::clone(&split_exec), 0).await; + let expected = vec![ + "+---------+------------+", + "| int_col | str_col |", + "+---------+------------+", + "| -2 | negative 2 |", + "| 1 | one |", + "| 2 | two |", + "+---------+------------+", + ]; + assert_batches_sorted_eq!(&expected, &output0); + + let output1 = test_collect_partition(split_exec, 1).await; + let expected = vec![ + "+---------+---------+", + "| int_col | str_col |", + "+---------+---------+", + "| 3 | three |", + "| 4 | four |", + "+---------+---------+", + ]; + assert_batches_sorted_eq!(&expected, &output1); + } + + #[tokio::test] + async fn test_basic_split_multi_exprs() { + test_helpers::maybe_start_logging(); + let batch0 = RecordBatch::try_from_iter(vec![ + ( + "int_col", + Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ( + "str_col", + Arc::new(StringArray::from(vec!["one", "two", "three"])) as ArrayRef, + ), + ]) + .unwrap(); + + let batch1 = RecordBatch::try_from_iter(vec![ + ( + "int_col", + Arc::new(Int64Array::from(vec![4, -2])) as ArrayRef, + ), + ( + "str_col", + Arc::new(StringArray::from(vec!["four", "negative 2"])) as ArrayRef, + ), + ]) + .unwrap(); + + let input = make_input(vec![vec![batch0, batch1]]); + // int_col < 2 + let split_expr1 = + df_physical_expr(input.schema(), col("int_col").lt(lit::(2))).unwrap(); + // 2 <= int_col < 3 + let expr = col("int_col") + .gt_eq(lit::(2)) + .and(col("int_col").lt(lit::(3))); + let 
split_expr2 = df_physical_expr(input.schema(), expr).unwrap(); + let split_exec: Arc = + Arc::new(StreamSplitExec::new(input, vec![split_expr1, split_expr2])); + + let output0 = test_collect_partition(Arc::clone(&split_exec), 0).await; + let expected = vec![ + "+---------+------------+", + "| int_col | str_col |", + "+---------+------------+", + "| -2 | negative 2 |", + "| 1 | one |", + "+---------+------------+", + ]; + assert_batches_sorted_eq!(&expected, &output0); + + let output1 = test_collect_partition(Arc::clone(&split_exec), 1).await; + let expected = vec![ + "+---------+---------+", + "| int_col | str_col |", + "+---------+---------+", + "| 2 | two |", + "+---------+---------+", + ]; + assert_batches_sorted_eq!(&expected, &output1); + + let output2 = test_collect_partition(split_exec, 2).await; + let expected = vec![ + "+---------+---------+", + "| int_col | str_col |", + "+---------+---------+", + "| 3 | three |", + "| 4 | four |", + "+---------+---------+", + ]; + assert_batches_sorted_eq!(&expected, &output2); + } + + #[tokio::test] + async fn test_constant_split() { + // test that it works with a constant expression + test_helpers::maybe_start_logging(); + let batch0 = RecordBatch::try_from_iter(vec![( + "int_col", + Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef, + )]) + .unwrap(); + + let input = make_input(vec![vec![batch0]]); + // use `false` to send all outputs to second stream + let split_expr = df_physical_expr(input.schema(), lit(false)).unwrap(); + let split_exec: Arc = + Arc::new(StreamSplitExec::new(input, vec![split_expr])); + + let output0 = test_collect_partition(Arc::clone(&split_exec), 0).await; + let expected = vec!["+---------+", "| int_col |", "+---------+", "+---------+"]; + assert_batches_sorted_eq!(&expected, &output0); + + let output1 = test_collect_partition(split_exec, 1).await; + let expected = vec![ + "+---------+", + "| int_col |", + "+---------+", + "| 1 |", + "| 2 |", + "| 3 |", + "+---------+", + ]; + 
assert_batches_sorted_eq!(&expected, &output1); + } + + #[tokio::test] + async fn test_constant_split_multi_exprs() { + // test that it works with a constant expression + test_helpers::maybe_start_logging(); + let batch0 = RecordBatch::try_from_iter(vec![( + "int_col", + Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef, + )]) + .unwrap(); + + // Test 1: 3 streams but all data is sent to the second one + let input = make_input(vec![vec![batch0.clone()]]); + // use `false` & `true` to send all outputs to second stream + let split_expr1 = df_physical_expr(input.schema(), lit(false)).unwrap(); + let split_expr2 = df_physical_expr(input.schema(), lit(true)).unwrap(); + let split_exec: Arc = + Arc::new(StreamSplitExec::new(input, vec![split_expr1, split_expr2])); + + let output0 = test_collect_partition(Arc::clone(&split_exec), 0).await; + let expected = vec!["+---------+", "| int_col |", "+---------+", "+---------+"]; + assert_batches_sorted_eq!(&expected, &output0); + + let output1 = test_collect_partition(Arc::clone(&split_exec), 1).await; + let expected = vec![ + "+---------+", + "| int_col |", + "+---------+", + "| 1 |", + "| 2 |", + "| 3 |", + "+---------+", + ]; + assert_batches_sorted_eq!(&expected, &output1); + + let output2 = test_collect_partition(split_exec, 2).await; + let expected = vec!["+---------+", "| int_col |", "+---------+", "+---------+"]; + assert_batches_sorted_eq!(&expected, &output2); + + // ----------------------- + // Test 2: 3 streams but all data is sent to the last one + let input = make_input(vec![vec![batch0.clone()]]); + + // use `false` & `false` to send all outputs to third stream + let split_expr1 = df_physical_expr(input.schema(), lit(false)).unwrap(); + let split_expr2 = df_physical_expr(input.schema(), lit(false)).unwrap(); + let split_exec: Arc = + Arc::new(StreamSplitExec::new(input, vec![split_expr1, split_expr2])); + + let output0 = test_collect_partition(Arc::clone(&split_exec), 0).await; + let expected = 
vec!["+---------+", "| int_col |", "+---------+", "+---------+"]; + assert_batches_sorted_eq!(&expected, &output0); + + let output1 = test_collect_partition(Arc::clone(&split_exec), 1).await; + let expected = vec!["+---------+", "| int_col |", "+---------+", "+---------+"]; + assert_batches_sorted_eq!(&expected, &output1); + + let output2 = test_collect_partition(Arc::clone(&split_exec), 2).await; + let expected = vec![ + "+---------+", + "| int_col |", + "+---------+", + "| 1 |", + "| 2 |", + "| 3 |", + "+---------+", + ]; + assert_batches_sorted_eq!(&expected, &output2); + + // ----------------------- + // Test 3: 3 streams but all data is sent to the first + let input = make_input(vec![vec![batch0]]); + + // use `true` & `false` to send all outputs to first stream + let split_expr1 = df_physical_expr(input.schema(), lit(true)).unwrap(); + let split_expr2 = df_physical_expr(input.schema(), lit(false)).unwrap(); + let split_exec: Arc = + Arc::new(StreamSplitExec::new(input, vec![split_expr1, split_expr2])); + + let output0 = test_collect_partition(Arc::clone(&split_exec), 0).await; + let expected = vec![ + "+---------+", + "| int_col |", + "+---------+", + "| 1 |", + "| 2 |", + "| 3 |", + "+---------+", + ]; + assert_batches_sorted_eq!(&expected, &output0); + + let output1 = test_collect_partition(Arc::clone(&split_exec), 1).await; + let expected = vec!["+---------+", "| int_col |", "+---------+", "+---------+"]; + assert_batches_sorted_eq!(&expected, &output1); + + let output2 = test_collect_partition(Arc::clone(&split_exec), 2).await; + let expected = vec!["+---------+", "| int_col |", "+---------+", "+---------+"]; + assert_batches_sorted_eq!(&expected, &output2); + } + + #[tokio::test] + async fn test_nulls() { + // test with null inputs (so rows evaluate to null) + + test_helpers::maybe_start_logging(); + let batch0 = RecordBatch::try_from_iter(vec![( + "int_col", + Arc::new(Int64Array::from(vec![Some(1), None, Some(2), Some(3)])) as ArrayRef, + )]) + 
.unwrap(); + + let input = make_input(vec![vec![batch0]]); + // int_col < 3 + let split_expr = df_physical_expr(input.schema(), col("int_col").lt(lit(3))).unwrap(); + let split_exec: Arc = + Arc::new(StreamSplitExec::new(input, vec![split_expr])); + + let output0 = test_collect_partition(Arc::clone(&split_exec), 0).await; + let expected = vec![ + "+---------+", + "| int_col |", + "+---------+", + "| 1 |", + "| 2 |", + "+---------+", + ]; + assert_batches_sorted_eq!(&expected, &output0); + + let output1 = test_collect_partition(split_exec, 1).await; + let expected = vec![ + "+---------+", + "| int_col |", + "+---------+", + "| |", + "| 3 |", + "+---------+", + ]; + assert_batches_sorted_eq!(&expected, &output1); + } + + #[tokio::test] + async fn test_nulls_multi_exprs() { + // test with null inputs (so rows evaluate to null) + + test_helpers::maybe_start_logging(); + let batch0 = RecordBatch::try_from_iter(vec![( + "int_col", + Arc::new(Int64Array::from(vec![Some(1), None, Some(2), Some(3)])) as ArrayRef, + )]) + .unwrap(); + + let input = make_input(vec![vec![batch0]]); + // int_col < 2 + let split_expr1 = + df_physical_expr(input.schema(), col("int_col").lt(lit::(2))).unwrap(); + // 2 <= int_col < 3 + let expr = col("int_col") + .gt_eq(lit::(2)) + .and(col("int_col").lt(lit::(3))); + let split_expr2 = df_physical_expr(input.schema(), expr).unwrap(); + let split_exec: Arc = + Arc::new(StreamSplitExec::new(input, vec![split_expr1, split_expr2])); + + let output0 = test_collect_partition(Arc::clone(&split_exec), 0).await; + let expected = vec![ + "+---------+", + "| int_col |", + "+---------+", + "| 1 |", + "+---------+", + ]; + assert_batches_sorted_eq!(&expected, &output0); + + let output1 = test_collect_partition(Arc::clone(&split_exec), 1).await; + let expected = vec![ + "+---------+", + "| int_col |", + "+---------+", + "| 2 |", + "+---------+", + ]; + assert_batches_sorted_eq!(&expected, &output1); + + let output2 = test_collect_partition(split_exec, 2).await; 
+ let expected = vec![ + "+---------+", + "| int_col |", + "+---------+", + "| |", + "| 3 |", + "+---------+", + ]; + assert_batches_sorted_eq!(&expected, &output2); + } + + #[tokio::test] + #[should_panic(expected = "Expected boolean array, but had type Int64")] + async fn test_non_bool() { + // test non boolean expression (expect error) + + test_helpers::maybe_start_logging(); + let batch0 = RecordBatch::try_from_iter(vec![( + "int_col", + Arc::new(Int64Array::from(vec![Some(1), None, Some(2), Some(3)])) as ArrayRef, + )]) + .unwrap(); + + let input = make_input(vec![vec![batch0]]); + // int_col (not a boolean) + let split_expr = df_physical_expr(input.schema(), col("int_col")).unwrap(); + let split_exec: Arc = + Arc::new(StreamSplitExec::new(input, vec![split_expr])); + + test_collect_partition(split_exec, 0).await; + } + + fn make_input(partitions: Vec>) -> Arc { + let schema = partitions + .iter() + .flat_map(|p| p.iter()) + .map(|batch| batch.schema()) + .next() + .expect("must be at least one batch"); + + let projection = None; + let input = + MemoryExec::try_new(&partitions, schema, projection).expect("Created MemoryExec"); + Arc::new(input) + } +} diff --git a/iox_query/src/exec/stringset.rs b/iox_query/src/exec/stringset.rs new file mode 100644 index 0000000..69fe153 --- /dev/null +++ b/iox_query/src/exec/stringset.rs @@ -0,0 +1,149 @@ +//! This module contains the definition of a "StringSet" a set of +//! logical strings and the code to create them from record batches. 
+ +use std::{collections::BTreeSet, sync::Arc}; + +use arrow::{ + array::{Array, DictionaryArray, StringArray}, + datatypes::{DataType, Int32Type, SchemaRef}, + record_batch::RecordBatch, +}; + +use snafu::{ensure, Snafu}; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display( + "Error extracting results from Record Batches: schema not a single Utf8 or string dictionary: {:?}", + schema + ))] + InternalSchemaWasNotString { schema: SchemaRef }, + + #[snafu(display("Internal error, unexpected null value"))] + InternalUnexpectedNull {}, + + #[snafu(display( + "Error reading record batch while converting to StringSet: {:?}", + source + ))] + ReadingRecordBatch { source: arrow::error::ArrowError }, +} + +pub type Result = std::result::Result; + +pub type StringSet = BTreeSet; +pub type StringSetRef = Arc; + +/// Trait to convert RecordBatch'y things into +/// `StringSetRef`s. Assumes that the input record batches each have a +/// single string column. Can return errors, so don't use +/// `std::convert::From` +pub trait IntoStringSet { + /// Convert this thing into a stringset + fn into_stringset(self) -> Result; +} + +impl IntoStringSet for &[&str] { + fn into_stringset(self) -> Result { + let set: StringSet = self.iter().map(|s| s.to_string()).collect(); + Ok(Arc::new(set)) + } +} + +/// Converts record batches into StringSets. 
+impl IntoStringSet for Vec { + fn into_stringset(self) -> Result { + let mut strings = StringSet::new(); + + // process the record batches one by one + for record_batch in self.into_iter() { + let num_rows = record_batch.num_rows(); + let schema = record_batch.schema(); + let fields = schema.fields(); + ensure!( + fields.len() == 1, + InternalSchemaWasNotStringSnafu { + schema: Arc::clone(&schema), + } + ); + + let field = &fields[0]; + + match field.data_type() { + DataType::Utf8 => { + let array = record_batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + add_utf8_array_to_stringset(&mut strings, array, num_rows)?; + } + DataType::Dictionary(key, value) + if key.as_ref() == &DataType::Int32 && value.as_ref() == &DataType::Utf8 => + { + let array = record_batch + .column(0) + .as_any() + .downcast_ref::>() + .unwrap(); + + add_utf8_dictionary_to_stringset(&mut strings, array, num_rows)?; + } + _ => InternalSchemaWasNotStringSnafu { + schema: Arc::clone(&schema), + } + .fail()?, + } + } + Ok(StringSetRef::new(strings)) + } +} + +fn add_utf8_array_to_stringset( + dest: &mut StringSet, + src: &StringArray, + num_rows: usize, +) -> Result<()> { + for i in 0..num_rows { + // Not sure how to handle a NULL -- StringSet contains + // Strings, not Option + if src.is_null(i) { + return InternalUnexpectedNullSnafu {}.fail(); + } else { + let src_value = src.value(i); + if !dest.contains(src_value) { + dest.insert(src_value.into()); + } + } + } + Ok(()) +} + +fn add_utf8_dictionary_to_stringset( + dest: &mut StringSet, + dictionary: &DictionaryArray, + num_rows: usize, +) -> Result<()> { + let keys = dictionary.keys(); + let values = dictionary.values(); + let values = values.as_any().downcast_ref::().unwrap(); + + // It might be quicker to construct an intermediate collection + // of unique indexes and then hydrate them + + for i in 0..num_rows { + // Not sure how to handle a NULL -- StringSet contains + // Strings, not Option + if keys.is_null(i) { + return 
InternalUnexpectedNullSnafu {}.fail(); + } else { + let idx = keys.value(i); + let src_value = values.value(idx as _); + if !dest.contains(src_value) { + dest.insert(src_value.into()); + } + } + } + Ok(()) +} diff --git a/iox_query/src/frontend.rs b/iox_query/src/frontend.rs new file mode 100644 index 0000000..7a6bd64 --- /dev/null +++ b/iox_query/src/frontend.rs @@ -0,0 +1,251 @@ +pub mod reorg; +pub mod sql; + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use datafusion::physical_plan::{ + metrics::{self, MetricValue}, + ExecutionPlan, ExecutionPlanVisitor, + }; + use datafusion_util::test_execute_partition; + use futures::StreamExt; + use schema::{merge::SchemaMerger, sort::SortKey, Schema}; + + use crate::{ + exec::{split::StreamSplitExec, Executor, ExecutorType}, + frontend::reorg::ReorgPlanner, + provider::{DeduplicateExec, RecordBatchesExec}, + test::TestChunk, + QueryChunk, + }; + + /// A macro to asserts the contents of the extracted metrics is reasonable + /// + macro_rules! 
assert_extracted_metrics { + ($EXTRACTED: expr, $EXPECTED_OUTPUT_ROWS: expr) => { + assert!( + $EXTRACTED.elapsed_compute.value() > 0, + "some elapsed compute time" + ); + assert_eq!( + $EXTRACTED.output_rows.value(), + $EXPECTED_OUTPUT_ROWS, + "expected output row count" + ); + + let start_ts = $EXTRACTED + .start_timestamp + .value() + .expect("start timestamp") + .timestamp_nanos_opt() + .expect("start timestamp in range"); + let end_ts = $EXTRACTED + .end_timestamp + .value() + .expect("end timestamp") + .timestamp_nanos_opt() + .expect("end timestamp in range"); + + assert!(start_ts > 0, "start timestamp was non zero"); + assert!(end_ts > 0, "end timestamp was non zero"); + assert!( + start_ts < end_ts, + "start timestamp was before end timestamp" + ); + }; + } + + #[tokio::test] + async fn test_metrics() { + test_helpers::maybe_start_logging(); + let (schema, chunks) = get_test_chunks(); + let sort_key = SortKey::from_columns(vec!["time", "tag1"]); + + // Use a split plan as it has StreamSplitExec, DeduplicateExec and IOxReadFilternode + let split_plan = ReorgPlanner::new() + .split_plan(Arc::from("t"), &schema, chunks, sort_key, vec![1000]) + .expect("created compact plan"); + + let executor = Executor::new_testing(); + let plan = executor + .new_context(ExecutorType::Reorg) + .create_physical_plan(&split_plan) + .await + .unwrap(); + + assert_eq!(plan.output_partitioning().partition_count(), 2); + + println!("Executing partition 0"); + let mut stream0 = test_execute_partition(Arc::clone(&plan), 0).await; + let mut num_rows = 0; + while let Some(batch) = stream0.next().await { + num_rows += batch.unwrap().num_rows(); + } + assert_eq!(num_rows, 3); + + println!("Executing partition 1"); + let mut stream1 = test_execute_partition(Arc::clone(&plan), 1).await; + let mut num_rows = 0; + while let Some(batch) = stream1.next().await { + num_rows += batch.unwrap().num_rows(); + } + assert_eq!(num_rows, 5); + + // now validate metrics are good + let extracted = 
extract_metrics(plan.as_ref(), |plan| { + plan.as_any().downcast_ref::().is_some() + }) + .unwrap(); + + assert_extracted_metrics!(extracted, 9); + + // now the deduplicator + let extracted = extract_metrics(plan.as_ref(), |plan| { + plan.as_any().downcast_ref::().is_some() + }) + .unwrap(); + + assert_extracted_metrics!(extracted, 3); + + // now the the split + let extracted = extract_metrics(plan.as_ref(), |plan| { + plan.as_any().downcast_ref::().is_some() + }) + .unwrap(); + + assert_extracted_metrics!(extracted, 8); + } + + // Extracted baseline metrics for the specified operator + #[derive(Debug)] + struct ExtractedMetrics { + elapsed_compute: metrics::Time, + output_rows: metrics::Count, + start_timestamp: metrics::Timestamp, + end_timestamp: metrics::Timestamp, + } + + // walks a plan tree, looking for the first plan node where a + // predicate returns true and extracts the common metrics + struct MetricsExtractor

+ where + P: FnMut(&dyn ExecutionPlan) -> bool, + { + pred: P, + inner: Option, + } + + impl

ExecutionPlanVisitor for MetricsExtractor

+ where + P: FnMut(&dyn ExecutionPlan) -> bool, + { + type Error = std::convert::Infallible; + + fn pre_visit( + &mut self, + plan: &dyn ExecutionPlan, + ) -> std::result::Result { + // not visiting this one + if !(self.pred)(plan) { + return Ok(true); + } + let metrics = plan.metrics().unwrap().aggregate_by_name(); + let mut elapsed_compute: Option = None; + let mut output_rows: Option = None; + let mut start_timestamp: Option = None; + let mut end_timestamp: Option = None; + + metrics.iter().for_each(|m| match m.value() { + MetricValue::ElapsedCompute(t) => { + assert!(elapsed_compute.is_none()); + elapsed_compute = Some(t.clone()) + } + MetricValue::OutputRows(c) => { + assert!(output_rows.is_none()); + output_rows = Some(c.clone()) + } + MetricValue::StartTimestamp(ts) => { + assert!(start_timestamp.is_none()); + start_timestamp = Some(ts.clone()) + } + MetricValue::EndTimestamp(ts) => { + assert!(end_timestamp.is_none()); + end_timestamp = Some(ts.clone()) + } + _ => {} + }); + + let new = ExtractedMetrics { + elapsed_compute: elapsed_compute.expect("did not find metric"), + output_rows: output_rows.expect("did not find metric"), + start_timestamp: start_timestamp.expect("did not find metric"), + end_timestamp: end_timestamp.expect("did not find metric"), + }; + + if let Some(existing) = &self.inner { + let ExtractedMetrics { + elapsed_compute, + output_rows, + start_timestamp, + end_timestamp, + } = existing; + new.elapsed_compute.add(elapsed_compute); + new.output_rows.add(output_rows.value()); + new.start_timestamp.update_to_min(start_timestamp); + new.end_timestamp.update_to_max(end_timestamp); + } + self.inner = Some(new); + + // found what we are looking for, no need to continue + Ok(false) + } + } + + fn extract_metrics

(plan: &dyn ExecutionPlan, pred: P) -> Option + where + P: FnMut(&dyn ExecutionPlan) -> bool, + { + let mut extractor = MetricsExtractor { pred, inner: None }; + + datafusion::physical_plan::accept(plan, &mut extractor).unwrap(); + + extractor.inner + } + + fn get_test_chunks() -> (Schema, Vec>) { + let max_time = 7000; + let chunk1 = Arc::new( + TestChunk::new("t") + .with_order(1) + .with_partition(1) + .with_time_column_with_stats(Some(50), Some(max_time)) + .with_tag_column_with_stats("tag1", Some("AL"), Some("MT")) + .with_i64_field_column("field_int") + .with_five_rows_of_data(), + ); + + // Chunk 2 has an extra field, and only 4 rows + let chunk2 = Arc::new( + TestChunk::new("t") + .with_order(2) + .with_partition(1) + .with_time_column_with_stats(Some(28000), Some(220000)) + .with_tag_column_with_stats("tag1", Some("UT"), Some("WA")) + .with_i64_field_column("field_int") + .with_i64_field_column("field_int2") + .with_may_contain_pk_duplicates(true) + .with_four_rows_of_data(), + ); + + let schema = SchemaMerger::new() + .merge(chunk1.schema()) + .unwrap() + .merge(chunk2.schema()) + .unwrap() + .build(); + + (schema, vec![chunk1, chunk2]) + } +} diff --git a/iox_query/src/frontend/reorg.rs b/iox_query/src/frontend/reorg.rs new file mode 100644 index 0000000..9bf8259 --- /dev/null +++ b/iox_query/src/frontend/reorg.rs @@ -0,0 +1,732 @@ +//! planning for physical reorganization operations (e.g. 
COMPACT) + +use std::sync::Arc; + +use datafusion::{logical_expr::LogicalPlan, prelude::col}; +use datafusion_util::lit_timestamptz_nano; +use observability_deps::tracing::debug; +use schema::{sort::SortKey, Schema, TIME_COLUMN_NAME}; + +use crate::{ + exec::make_stream_split, provider::ProviderBuilder, util::logical_sort_key_exprs, QueryChunk, +}; +use snafu::{ResultExt, Snafu}; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("Chunk schema not compatible for compact plan: {}", source))] + ChunkSchemaNotCompatible { source: schema::merge::Error }, + + #[snafu(display("Reorg planner got error building plan: {}", source))] + BuildingPlan { + source: datafusion::error::DataFusionError, + }, + + #[snafu(display( + "Reorg planner got error adding creating scan for {}: {}", + table_name, + source + ))] + CreatingScan { + table_name: String, + source: crate::provider::Error, + }, +} +pub type Result = std::result::Result; + +impl From for Error { + fn from(source: datafusion::error::DataFusionError) -> Self { + Self::BuildingPlan { source } + } +} + +/// Planner for physically rearranging chunk data. This planner +/// creates COMPACT and SPLIT plans for use in the database lifecycle manager +#[derive(Debug, Default, Copy, Clone)] +pub struct ReorgPlanner {} + +impl ReorgPlanner { + pub fn new() -> Self { + Self::default() + } + + /// Creates an execution plan for the COMPACT operations which does the following: + /// + /// 1. Merges chunks together into a single stream + /// 2. Deduplicates via PK as necessary + /// 3. 
Sorts the result according to the requested `output_sort_key` (if necessary) + /// + /// The plan looks like: + /// + /// ```text + /// (Optional Sort on output_sort_key) + /// (Scan chunks) <-- any needed deduplication happens here + /// ``` + pub fn compact_plan( + &self, + table_name: Arc, + schema: &Schema, + chunks: I, + output_sort_key: SortKey, + ) -> Result + where + I: IntoIterator>, + { + let mut builder = ProviderBuilder::new(Arc::clone(&table_name), schema.clone()) + .with_enable_deduplication(true); + + for chunk in chunks { + builder = builder.add_chunk(chunk); + } + + let provider = builder.build().context(CreatingScanSnafu { + table_name: table_name.as_ref(), + })?; + let plan_builder = Arc::new(provider) + .into_logical_plan_builder() + .context(BuildingPlanSnafu)?; + let sort_expr = logical_sort_key_exprs(&output_sort_key); + let plan = plan_builder + .sort(sort_expr) + .context(BuildingPlanSnafu)? + .build() + .context(BuildingPlanSnafu)?; + + debug!(table_name=table_name.as_ref(), plan=%plan.display_indent_schema(), + "created compact plan for table"); + + Ok(plan) + } + + /// Creates an execution plan for the SPLIT operations which does the following: + /// + /// 1. Merges chunks together into a single stream + /// 2. Deduplicates via PK as necessary + /// 3. Sorts the result according to the requested output_sort_key + /// 4. Splits the stream on value of the `time` column: Those + /// rows that are on or before the time and those that are after + /// + /// The plan looks like: + /// + /// ```text + /// (Split on Time) + /// (Sort on output_sort) + /// (Scan chunks) <-- any needed deduplication happens here + /// ``` + /// + /// The output execution plan has `N` "output streams" (DataFusion + /// partitions) where `N` = `split_times.len() + 1`. 
The + /// time ranges of the streams are: + /// + /// Stream 0: Rows that have `time` *on or before* the `split_times[0]` + /// + /// Stream i, where 0 < i < split_times.len(): + /// Rows have: `time` in range `(split_times[i-1], split_times[i]]`, + /// Which is: greater than `split_times[i-1]` up to and including `split_times[i]`. + /// + /// Stream n, where n = split_times.len()): Rows that have `time` + /// *after* `split_times[n-1]` as well as NULL rows + /// + /// # Panics + /// + /// The code will panic if split_times are not in monotonically increasing order + /// + /// # Example + /// if the input looks like: + /// ```text + /// X | time + /// ---+----- + /// b | 2000 + /// a | 1000 + /// c | 4000 + /// d | 2000 + /// e | 3000 + /// ``` + /// A split plan with `sort=time` and `split_times=[2000, 3000]` will produce the following three output streams + /// + /// ```text + /// X | time + /// ---+----- + /// a | 1000 + /// b | 2000 + /// d | 2000 + /// ``` + /// + /// ```text + /// X | time + /// ---+----- + /// e | 3000 + /// ``` + /// + /// ```text + /// X | time + /// ---+----- + /// c | 4000 + /// ``` + pub fn split_plan( + &self, + table_name: Arc, + schema: &Schema, + chunks: I, + output_sort_key: SortKey, + split_times: Vec, + ) -> Result + where + I: IntoIterator>, + { + // split_times must have values + if split_times.is_empty() { + panic!("Split plan does not accept empty split_times"); + } + + let mut builder = ProviderBuilder::new(Arc::clone(&table_name), schema.clone()) + .with_enable_deduplication(true); + + for chunk in chunks { + builder = builder.add_chunk(chunk); + } + + let provider = builder.build().context(CreatingScanSnafu { + table_name: table_name.as_ref(), + })?; + let plan_builder = Arc::new(provider) + .into_logical_plan_builder() + .context(BuildingPlanSnafu)?; + let sort_expr = logical_sort_key_exprs(&output_sort_key); + let plan = plan_builder + .sort(sort_expr) + .context(BuildingPlanSnafu)? 
+ .build() + .context(BuildingPlanSnafu)?; + + let mut split_exprs = Vec::with_capacity(split_times.len()); + // time <= split_times[0] + split_exprs.push(col(TIME_COLUMN_NAME).lt_eq(lit_timestamptz_nano(split_times[0]))); + // split_times[i-1] , time <= split_time[i] + for i in 1..split_times.len() { + if split_times[i - 1] >= split_times[i] { + panic!( + "split_times[{}]: {} must be smaller than split_times[{}]: {}", + i - 1, + split_times[i - 1], + i, + split_times[i] + ); + } + split_exprs.push( + col(TIME_COLUMN_NAME) + .gt(lit_timestamptz_nano(split_times[i - 1])) + .and(col(TIME_COLUMN_NAME).lt_eq(lit_timestamptz_nano(split_times[i]))), + ); + } + let plan = make_stream_split(plan, split_exprs); + + debug!(table_name=table_name.as_ref(), plan=%plan.display_indent_schema(), + "created split plan for table"); + + Ok(plan) + } +} + +#[cfg(test)] +mod test { + use arrow_util::assert_batches_eq; + use datafusion_util::{test_collect, test_collect_partition}; + use schema::merge::SchemaMerger; + use schema::sort::SortKeyBuilder; + + use crate::{ + exec::{Executor, ExecutorType}, + test::{format_execution_plan, raw_data, TestChunk}, + }; + + use super::*; + + async fn get_test_chunks() -> (Schema, Vec>) { + // Chunk 1 with 5 rows of data on 2 tags + let chunk1 = Arc::new( + TestChunk::new("t") + .with_time_column_with_stats(Some(50), Some(7000)) + .with_tag_column_with_stats("tag1", Some("AL"), Some("MT")) + .with_i64_field_column("field_int") + .with_five_rows_of_data(), + ) as Arc; + + // Chunk 2 has an extra field, and only 4 fields + let chunk2 = Arc::new( + TestChunk::new("t") + .with_time_column_with_stats(Some(28000), Some(220000)) + .with_tag_column_with_stats("tag1", Some("UT"), Some("WA")) + .with_i64_field_column("field_int") + .with_i64_field_column("field_int2") + .with_may_contain_pk_duplicates(true) + .with_four_rows_of_data(), + ) as Arc; + + let expected = vec![ + "+-----------+------+--------------------------------+", + "| field_int | tag1 | time 
|", + "+-----------+------+--------------------------------+", + "| 1000 | MT | 1970-01-01T00:00:00.000001Z |", + "| 10 | MT | 1970-01-01T00:00:00.000007Z |", + "| 70 | CT | 1970-01-01T00:00:00.000000100Z |", + "| 100 | AL | 1970-01-01T00:00:00.000000050Z |", + "| 5 | MT | 1970-01-01T00:00:00.000005Z |", + "+-----------+------+--------------------------------+", + ]; + assert_batches_eq!(&expected, &raw_data(&[Arc::clone(&chunk1)]).await); + + let expected = vec![ + "+-----------+------------+------+-----------------------------+", + "| field_int | field_int2 | tag1 | time |", + "+-----------+------------+------+-----------------------------+", + "| 1000 | 1000 | WA | 1970-01-01T00:00:00.000028Z |", + "| 10 | 10 | VT | 1970-01-01T00:00:00.000210Z |", + "| 70 | 70 | UT | 1970-01-01T00:00:00.000220Z |", + "| 50 | 50 | VT | 1970-01-01T00:00:00.000210Z |", + "+-----------+------------+------+-----------------------------+", + ]; + assert_batches_eq!(&expected, &raw_data(&[Arc::clone(&chunk2)]).await); + + let schema = SchemaMerger::new() + .merge(chunk1.schema()) + .unwrap() + .merge(chunk2.schema()) + .unwrap() + .build(); + + (schema, vec![chunk1, chunk2]) + } + + async fn get_sorted_test_chunks() -> (Schema, Vec>) { + // Chunk 1 + let chunk1 = Arc::new( + TestChunk::new("t") + .with_time_column_with_stats(Some(1000), Some(1000)) + .with_tag_column_with_stats("tag1", Some("A"), Some("A")) + .with_i64_field_column("field_int") + .with_one_row_of_specific_data("A", 1, 1000), + ) as Arc; + + let expected = vec![ + "+-----------+------+-----------------------------+", + "| field_int | tag1 | time |", + "+-----------+------+-----------------------------+", + "| 1 | A | 1970-01-01T00:00:00.000001Z |", + "+-----------+------+-----------------------------+", + ]; + assert_batches_eq!(&expected, &raw_data(&[Arc::clone(&chunk1)]).await); + + // Chunk 2 + let chunk2 = Arc::new( + TestChunk::new("t") + .with_time_column_with_stats(Some(2000), Some(2000)) + 
.with_tag_column_with_stats("tag1", Some("B"), Some("B")) + .with_i64_field_column("field_int") + .with_one_row_of_specific_data("B", 2, 2000), + ) as Arc; + + let expected = vec![ + "+-----------+------+-----------------------------+", + "| field_int | tag1 | time |", + "+-----------+------+-----------------------------+", + "| 2 | B | 1970-01-01T00:00:00.000002Z |", + "+-----------+------+-----------------------------+", + ]; + assert_batches_eq!(&expected, &raw_data(&[Arc::clone(&chunk2)]).await); + + (chunk1.schema().clone(), vec![chunk1, chunk2]) + } + + #[tokio::test] + async fn test_compact_plan_sorted() { + test_helpers::maybe_start_logging(); + + // ensures that the output is actually sorted + // https://github.com/influxdata/influxdb_iox/issues/6125 + let (schema, chunks) = get_sorted_test_chunks().await; + + let chunk_orders = vec![ + // reverse order + vec![Arc::clone(&chunks[1]), Arc::clone(&chunks[0])], + chunks, + ]; + + // executor has only 1 thread + let executor = Executor::new_testing(); + for chunks in chunk_orders { + let sort_key = SortKeyBuilder::with_capacity(2) + .with_col_opts("tag1", false, true) + .with_col_opts(TIME_COLUMN_NAME, false, true) + .build(); + + let compact_plan = ReorgPlanner::new() + .compact_plan(Arc::from("t"), &schema, chunks, sort_key) + .expect("created compact plan"); + + let physical_plan = executor + .new_context(ExecutorType::Reorg) + .create_physical_plan(&compact_plan) + .await + .unwrap(); + + let batches = test_collect(physical_plan).await; + + // should be sorted on tag1 then timestamp + let expected = vec![ + "+-----------+------+-----------------------------+", + "| field_int | tag1 | time |", + "+-----------+------+-----------------------------+", + "| 1 | A | 1970-01-01T00:00:00.000001Z |", + "| 2 | B | 1970-01-01T00:00:00.000002Z |", + "+-----------+------+-----------------------------+", + ]; + + assert_batches_eq!(&expected, &batches); + } + } + + #[tokio::test] + async fn 
test_compact_plan_default_sort() { + test_helpers::maybe_start_logging(); + + let (schema, chunks) = get_test_chunks().await; + + let sort_key = SortKeyBuilder::with_capacity(2) + .with_col("tag1") + .with_col(TIME_COLUMN_NAME) + .build(); + + let compact_plan = ReorgPlanner::new() + .compact_plan(Arc::from("t"), &schema, chunks, sort_key) + .expect("created compact plan"); + + let executor = Executor::new_testing(); + let physical_plan = executor + .new_context(ExecutorType::Reorg) + .create_physical_plan(&compact_plan) + .await + .unwrap(); + + // It is critical that the plan only sorts the inputs and is not resorted after the UnionExec. + insta::assert_yaml_snapshot!( + format_execution_plan(&physical_plan), + @r###" + --- + - " SortPreservingMergeExec: [tag1@2 ASC,time@3 ASC]" + - " UnionExec" + - " SortExec: expr=[tag1@2 ASC,time@3 ASC]" + - " RecordBatchesExec: chunks=1, projection=[field_int, field_int2, tag1, time]" + - " ProjectionExec: expr=[field_int@1 as field_int, field_int2@2 as field_int2, tag1@3 as tag1, time@4 as time]" + - " DeduplicateExec: [tag1@3 ASC,time@4 ASC]" + - " SortExec: expr=[tag1@3 ASC,time@4 ASC,__chunk_order@0 ASC]" + - " RecordBatchesExec: chunks=1, projection=[__chunk_order, field_int, field_int2, tag1, time]" + "### + ); + + assert_eq!( + physical_plan.output_partitioning().partition_count(), + 1, + "{:?}", + physical_plan.output_partitioning() + ); + + let batches = test_collect(physical_plan).await; + + // sorted on state ASC and time ASC (defaults) + let expected = vec![ + "+-----------+------------+------+--------------------------------+", + "| field_int | field_int2 | tag1 | time |", + "+-----------+------------+------+--------------------------------+", + "| 100 | | AL | 1970-01-01T00:00:00.000000050Z |", + "| 70 | | CT | 1970-01-01T00:00:00.000000100Z |", + "| 1000 | | MT | 1970-01-01T00:00:00.000001Z |", + "| 5 | | MT | 1970-01-01T00:00:00.000005Z |", + "| 10 | | MT | 1970-01-01T00:00:00.000007Z |", + "| 70 | 70 | UT | 
1970-01-01T00:00:00.000220Z |", + "| 50 | 50 | VT | 1970-01-01T00:00:00.000210Z |", + "| 1000 | 1000 | WA | 1970-01-01T00:00:00.000028Z |", + "+-----------+------------+------+--------------------------------+", + ]; + + assert_batches_eq!(&expected, &batches); + } + + #[tokio::test] + async fn test_compact_plan_alternate_sort() { + test_helpers::maybe_start_logging(); + + let (schema, chunks) = get_test_chunks().await; + + let sort_key = SortKeyBuilder::with_capacity(2) + // use something other than the default sort + .with_col_opts("tag1", true, true) + .with_col_opts(TIME_COLUMN_NAME, false, false) + .build(); + + let compact_plan = ReorgPlanner::new() + .compact_plan(Arc::from("t"), &schema, chunks, sort_key) + .expect("created compact plan"); + + let executor = Executor::new_testing(); + let physical_plan = executor + .new_context(ExecutorType::Reorg) + .create_physical_plan(&compact_plan) + .await + .unwrap(); + + insta::assert_yaml_snapshot!( + format_execution_plan(&physical_plan), + @r###" + --- + - " SortPreservingMergeExec: [tag1@2 DESC,time@3 ASC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[tag1@2 DESC,time@3 ASC NULLS LAST]" + - " RecordBatchesExec: chunks=1, projection=[field_int, field_int2, tag1, time]" + - " SortExec: expr=[tag1@2 DESC,time@3 ASC NULLS LAST]" + - " ProjectionExec: expr=[field_int@1 as field_int, field_int2@2 as field_int2, tag1@3 as tag1, time@4 as time]" + - " DeduplicateExec: [tag1@3 ASC,time@4 ASC]" + - " SortExec: expr=[tag1@3 ASC,time@4 ASC,__chunk_order@0 ASC]" + - " RecordBatchesExec: chunks=1, projection=[__chunk_order, field_int, field_int2, tag1, time]" + "### + ); + + assert_eq!( + physical_plan.output_partitioning().partition_count(), + 1, + "{:?}", + physical_plan.output_partitioning() + ); + + let batches = test_collect(physical_plan).await; + + // sorted on state DESC and time ASC + let expected = vec![ + "+-----------+------------+------+--------------------------------+", + "| field_int | field_int2 | tag1 | 
time |", + "+-----------+------------+------+--------------------------------+", + "| 1000 | 1000 | WA | 1970-01-01T00:00:00.000028Z |", + "| 50 | 50 | VT | 1970-01-01T00:00:00.000210Z |", + "| 70 | 70 | UT | 1970-01-01T00:00:00.000220Z |", + "| 1000 | | MT | 1970-01-01T00:00:00.000001Z |", + "| 5 | | MT | 1970-01-01T00:00:00.000005Z |", + "| 10 | | MT | 1970-01-01T00:00:00.000007Z |", + "| 70 | | CT | 1970-01-01T00:00:00.000000100Z |", + "| 100 | | AL | 1970-01-01T00:00:00.000000050Z |", + "+-----------+------------+------+--------------------------------+", + ]; + + assert_batches_eq!(&expected, &batches); + } + + #[tokio::test] + async fn test_split_plan() { + test_helpers::maybe_start_logging(); + // validate that the plumbing is all hooked up. The logic of + // the operator is tested in its own module. + let (schema, chunks) = get_test_chunks().await; + + let sort_key = SortKeyBuilder::with_capacity(2) + .with_col_opts("time", false, false) + .with_col_opts("tag1", false, true) + .build(); + + // split on 1000 should have timestamps 1000, 5000, and 7000 + let split_plan = ReorgPlanner::new() + .split_plan(Arc::from("t"), &schema, chunks, sort_key, vec![1000]) + .expect("created compact plan"); + + let executor = Executor::new_testing(); + let physical_plan = executor + .new_context(ExecutorType::Reorg) + .create_physical_plan(&split_plan) + .await + .unwrap(); + + insta::assert_yaml_snapshot!( + format_execution_plan(&physical_plan), + @r###" + --- + - " StreamSplitExec" + - " SortPreservingMergeExec: [time@3 ASC NULLS LAST,tag1@2 ASC]" + - " UnionExec" + - " SortExec: expr=[time@3 ASC NULLS LAST,tag1@2 ASC]" + - " RecordBatchesExec: chunks=1, projection=[field_int, field_int2, tag1, time]" + - " SortExec: expr=[time@3 ASC NULLS LAST,tag1@2 ASC]" + - " ProjectionExec: expr=[field_int@1 as field_int, field_int2@2 as field_int2, tag1@3 as tag1, time@4 as time]" + - " DeduplicateExec: [tag1@3 ASC,time@4 ASC]" + - " SortExec: expr=[tag1@3 ASC,time@4 
ASC,__chunk_order@0 ASC]" + - " RecordBatchesExec: chunks=1, projection=[__chunk_order, field_int, field_int2, tag1, time]" + "### + ); + + assert_eq!( + physical_plan.output_partitioning().partition_count(), + 2, + "{:?}", + physical_plan.output_partitioning() + ); + + // verify that the stream was split + let batches0 = test_collect_partition(Arc::clone(&physical_plan), 0).await; + + // Note sorted on time + let expected = vec![ + "+-----------+------------+------+--------------------------------+", + "| field_int | field_int2 | tag1 | time |", + "+-----------+------------+------+--------------------------------+", + "| 100 | | AL | 1970-01-01T00:00:00.000000050Z |", + "| 70 | | CT | 1970-01-01T00:00:00.000000100Z |", + "| 1000 | | MT | 1970-01-01T00:00:00.000001Z |", + "+-----------+------------+------+--------------------------------+", + ]; + assert_batches_eq!(&expected, &batches0); + + let batches1 = test_collect_partition(physical_plan, 1).await; + + // Sorted on time + let expected = vec![ + "+-----------+------------+------+-----------------------------+", + "| field_int | field_int2 | tag1 | time |", + "+-----------+------------+------+-----------------------------+", + "| 5 | | MT | 1970-01-01T00:00:00.000005Z |", + "| 10 | | MT | 1970-01-01T00:00:00.000007Z |", + "| 1000 | 1000 | WA | 1970-01-01T00:00:00.000028Z |", + "| 50 | 50 | VT | 1970-01-01T00:00:00.000210Z |", + "| 70 | 70 | UT | 1970-01-01T00:00:00.000220Z |", + "+-----------+------------+------+-----------------------------+", + ]; + + assert_batches_eq!(&expected, &batches1); + } + + #[tokio::test] + async fn test_split_plan_multi_exps() { + test_helpers::maybe_start_logging(); + // validate that the plumbing is all hooked up. The logic of + // the operator is tested in its own module. 
+ let (schema, chunks) = get_test_chunks().await; + + let sort_key = SortKeyBuilder::with_capacity(2) + .with_col_opts("time", false, false) + .with_col_opts("tag1", false, true) + .build(); + + // split on 1000 and 7000 + let split_plan = ReorgPlanner::new() + .split_plan(Arc::from("t"), &schema, chunks, sort_key, vec![1000, 7000]) + .expect("created compact plan"); + + let executor = Executor::new_testing(); + let physical_plan = executor + .new_context(ExecutorType::Reorg) + .create_physical_plan(&split_plan) + .await + .unwrap(); + + insta::assert_yaml_snapshot!( + format_execution_plan(&physical_plan), + @r###" + --- + - " StreamSplitExec" + - " SortPreservingMergeExec: [time@3 ASC NULLS LAST,tag1@2 ASC]" + - " UnionExec" + - " SortExec: expr=[time@3 ASC NULLS LAST,tag1@2 ASC]" + - " RecordBatchesExec: chunks=1, projection=[field_int, field_int2, tag1, time]" + - " SortExec: expr=[time@3 ASC NULLS LAST,tag1@2 ASC]" + - " ProjectionExec: expr=[field_int@1 as field_int, field_int2@2 as field_int2, tag1@3 as tag1, time@4 as time]" + - " DeduplicateExec: [tag1@3 ASC,time@4 ASC]" + - " SortExec: expr=[tag1@3 ASC,time@4 ASC,__chunk_order@0 ASC]" + - " RecordBatchesExec: chunks=1, projection=[__chunk_order, field_int, field_int2, tag1, time]" + "### + ); + + assert_eq!( + physical_plan.output_partitioning().partition_count(), + 3, + "{:?}", + physical_plan.output_partitioning() + ); + + // Verify that the stream was split + + // Note sorted on time + // Should include time <= 1000 + let batches0 = test_collect_partition(Arc::clone(&physical_plan), 0).await; + let expected = vec![ + "+-----------+------------+------+--------------------------------+", + "| field_int | field_int2 | tag1 | time |", + "+-----------+------------+------+--------------------------------+", + "| 100 | | AL | 1970-01-01T00:00:00.000000050Z |", + "| 70 | | CT | 1970-01-01T00:00:00.000000100Z |", + "| 1000 | | MT | 1970-01-01T00:00:00.000001Z |", + 
"+-----------+------------+------+--------------------------------+", + ]; + assert_batches_eq!(&expected, &batches0); + + // Sorted on time + // Should include 1000 < time <= 7000 + let batches1 = test_collect_partition(Arc::clone(&physical_plan), 1).await; + let expected = vec![ + "+-----------+------------+------+-----------------------------+", + "| field_int | field_int2 | tag1 | time |", + "+-----------+------------+------+-----------------------------+", + "| 5 | | MT | 1970-01-01T00:00:00.000005Z |", + "| 10 | | MT | 1970-01-01T00:00:00.000007Z |", + "+-----------+------------+------+-----------------------------+", + ]; + assert_batches_eq!(&expected, &batches1); + + // Sorted on time + // Should include 7000 < time + let batches2 = test_collect_partition(physical_plan, 2).await; + let expected = vec![ + "+-----------+------------+------+-----------------------------+", + "| field_int | field_int2 | tag1 | time |", + "+-----------+------------+------+-----------------------------+", + "| 1000 | 1000 | WA | 1970-01-01T00:00:00.000028Z |", + "| 50 | 50 | VT | 1970-01-01T00:00:00.000210Z |", + "| 70 | 70 | UT | 1970-01-01T00:00:00.000220Z |", + "+-----------+------------+------+-----------------------------+", + ]; + assert_batches_eq!(&expected, &batches2); + } + + #[tokio::test] + #[should_panic(expected = "Split plan does not accept empty split_times")] + async fn test_split_plan_panic_empty() { + test_helpers::maybe_start_logging(); + // validate that the plumbing is all hooked up. The logic of + // the operator is tested in its own module. 
+ let (schema, chunks) = get_test_chunks().await; + + let sort_key = SortKeyBuilder::with_capacity(2) + .with_col_opts("time", false, false) + .with_col_opts("tag1", false, true) + .build(); + + // split on 1000 and 7000 + let _split_plan = ReorgPlanner::new() + .split_plan(Arc::from("t"), &schema, chunks, sort_key, vec![]) // reason of panic: empty split_times + .expect("created compact plan"); + } + + #[tokio::test] + #[should_panic(expected = "split_times[0]: 1000 must be smaller than split_times[1]: 500")] + async fn test_split_plan_panic_times() { + test_helpers::maybe_start_logging(); + // validate that the plumbing is all hooked up. The logic of + // the operator is tested in its own module. + let (schema, chunks) = get_test_chunks().await; + + let sort_key = SortKeyBuilder::with_capacity(2) + .with_col_opts("time", false, false) + .with_col_opts("tag1", false, true) + .build(); + + // split on 1000 and 7000 + let _split_plan = ReorgPlanner::new() + .split_plan(Arc::from("t"), &schema, chunks, sort_key, vec![1000, 500]) // reason of panic: split_times not in ascending order + .expect("created compact plan"); + } +} diff --git a/iox_query/src/frontend/sql.rs b/iox_query/src/frontend/sql.rs new file mode 100644 index 0000000..4008e3c --- /dev/null +++ b/iox_query/src/frontend/sql.rs @@ -0,0 +1,26 @@ +use std::sync::Arc; + +use crate::exec::context::IOxSessionContext; +use datafusion::{common::ParamValues, error::Result, physical_plan::ExecutionPlan}; + +/// This struct can create plans for running SQL queries against databases +#[derive(Debug, Default, Copy, Clone)] +pub struct SqlQueryPlanner {} + +impl SqlQueryPlanner { + pub fn new() -> Self { + Self::default() + } + + /// Plan a SQL query against the catalogs registered with `ctx`, and return a + /// DataFusion physical execution plan that runs on the query executor. 
+ pub async fn query( + &self, + query: &str, + params: impl Into + Send, + ctx: &IOxSessionContext, + ) -> Result> { + let ctx = ctx.child_ctx("SqlQueryPlanner::query"); + ctx.sql_to_physical_plan_with_params(query, params).await + } +} diff --git a/iox_query/src/lib.rs b/iox_query/src/lib.rs new file mode 100644 index 0000000..e5afb92 --- /dev/null +++ b/iox_query/src/lib.rs @@ -0,0 +1,227 @@ +//! Contains the IOx query engine +#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)] +#![warn( + missing_debug_implementations, + clippy::explicit_iter_loop, + clippy::use_self, + clippy::clone_on_ref_ptr, + // See https://github.com/influxdata/influxdb_iox/pull/1671 + clippy::future_not_send, + clippy::todo, + clippy::dbg_macro, + unused_crate_dependencies +)] +#![allow(unreachable_pub)] + +use datafusion_util::MemoryStream; +use futures::TryStreamExt; +use query_log::{QueryCompletedToken, QueryText, StateReceived}; +use trace::{ctx::SpanContext, span::Span}; + +use tracker::InstrumentedAsyncOwnedSemaphorePermit; +// Workaround for "unused crate" lint false positives. 
+use workspace_hack as _; + +use arrow::{ + datatypes::{DataType, Field, SchemaRef}, + record_batch::RecordBatch, +}; +use async_trait::async_trait; +use data_types::{ChunkId, ChunkOrder, TransitionPartitionId}; +use datafusion::{ + error::DataFusionError, + physical_plan::{SendableRecordBatchStream, Statistics}, + prelude::{Expr, SessionContext}, +}; +use exec::IOxSessionContext; +use once_cell::sync::Lazy; +use parquet_file::storage::ParquetExecInput; +use schema::{sort::SortKey, Projection, Schema}; +use std::{any::Any, fmt::Debug, sync::Arc}; + +pub mod chunk_statistics; +pub mod config; +pub mod exec; +pub mod frontend; +pub mod logical_optimizer; +pub mod physical_optimizer; +pub mod plan; +pub mod provider; +pub mod pruning; +pub mod query_log; +pub mod statistics; +pub mod util; + +pub use query_functions::group_by::{Aggregate, WindowDuration}; + +/// The name of the virtual column that represents the chunk order. +pub const CHUNK_ORDER_COLUMN_NAME: &str = "__chunk_order"; + +static CHUNK_ORDER_FIELD: Lazy> = + Lazy::new(|| Arc::new(Field::new(CHUNK_ORDER_COLUMN_NAME, DataType::Int64, false))); + +/// Generate [`Field`] for [chunk order column](CHUNK_ORDER_COLUMN_NAME). +pub fn chunk_order_field() -> Arc { + Arc::clone(&CHUNK_ORDER_FIELD) +} + +/// A single chunk of data. +pub trait QueryChunk: Debug + Send + Sync + 'static { + /// Return a statistics of the data + fn stats(&self) -> Arc; + + /// return a reference to the summary of the data held in this chunk + fn schema(&self) -> &Schema; + + /// Return partition identifier for this chunk + fn partition_id(&self) -> &TransitionPartitionId; + + /// return a reference to the sort key if any + fn sort_key(&self) -> Option<&SortKey>; + + /// returns the Id of this chunk. Ids are unique within a + /// particular partition. 
+ fn id(&self) -> ChunkId; + + /// Returns true if the chunk may contain a duplicate "primary + /// key" within itself + fn may_contain_pk_duplicates(&self) -> bool; + + /// Provides access to raw [`QueryChunk`] data. + /// + /// The engine assume that minimal work shall be performed to gather the `QueryChunkData`. + fn data(&self) -> QueryChunkData; + + /// Returns chunk type. Useful in tests and debug logs. + fn chunk_type(&self) -> &str; + + /// Order of this chunk relative to other overlapping chunks. + fn order(&self) -> ChunkOrder; + + /// Return backend as [`Any`] which can be used to downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; +} + +/// `QueryNamespace` is the main trait implemented by the IOx subsystems that store actual data. +/// +/// Namespaces store data organized by partitions and each partition stores data in Chunks. +#[async_trait] +pub trait QueryNamespace: Debug + Send + Sync { + /// Returns a set of chunks within the partition with data that may match the provided + /// filter expression. + /// + /// If possible, chunks which have no rows that can possibly match the filter may be omitted. + /// + /// If projection is `None`, returned chunks will include all columns of its original data. + /// Otherwise, returned chunks will include PK columns (tags and time) and columns specified in + /// the projection. Projecting chunks here is optional and a mere optimization. The query + /// subsystem does NOT rely on it. + async fn chunks( + &self, + table_name: &str, + filters: &[Expr], + projection: Option<&Vec>, + ctx: IOxSessionContext, + ) -> Result>, DataFusionError>; + + /// Retention cutoff time. + /// + /// This gives the timestamp (NOT the duration) at which data should be cut off. This should result in an additional + /// filter of the following form: + /// + /// ```text + /// time >= retention_time_ns + /// ``` + /// + /// Returns `None` if now retention policy was defined. 
+ fn retention_time_ns(&self) -> Option; + + /// Record that particular type of query was run / planned + fn record_query( + &self, + span_ctx: Option<&SpanContext>, + query_type: &'static str, + query_text: QueryText, + ) -> QueryCompletedToken; + + /// Returns a new execution context suitable for running queries + fn new_query_context(&self, span_ctx: Option) -> IOxSessionContext; +} + +/// Trait that allows the query engine (which includes flight and storage/InfluxRPC) to access a +/// virtual set of namespaces. +/// +/// This is the only entry point for the query engine. This trait and the traits reachable by it (e.g. +/// [`QueryNamespace`]) are the only wait to access the catalog and payload data. +#[async_trait] +pub trait QueryNamespaceProvider: std::fmt::Debug + Send + Sync + 'static { + /// Get namespace if it exists. + /// + /// System tables may contain debug information depending on `include_debug_info_tables`. + async fn db( + &self, + name: &str, + span: Option, + include_debug_info_tables: bool, + ) -> Option>; + + /// Acquire concurrency-limiting sempahore + async fn acquire_semaphore(&self, span: Option) -> InstrumentedAsyncOwnedSemaphorePermit; +} + +/// Raw data of a [`QueryChunk`]. +pub enum QueryChunkData { + /// Record batches. + RecordBatches(SendableRecordBatchStream), + + /// Parquet file. + /// + /// See [`ParquetExecInput`] for details. + Parquet(ParquetExecInput), +} + +impl QueryChunkData { + /// Read data into [`RecordBatch`]es. This is mostly meant for testing! + pub async fn read_to_batches( + self, + schema: &Schema, + session_ctx: &SessionContext, + ) -> Vec { + match self { + Self::RecordBatches(batches) => batches.try_collect::>().await.unwrap(), + Self::Parquet(exec_input) => exec_input + .read_to_batches(schema.as_arrow(), Projection::All, session_ctx) + .await + .unwrap(), + } + } + + /// Create data based on batches and schema. 
+ pub fn in_mem(batches: Vec, schema: SchemaRef) -> Self { + let s = MemoryStream::new_with_schema(batches, schema); + let s: SendableRecordBatchStream = Box::pin(s); + Self::RecordBatches(s) + } +} + +impl std::fmt::Debug for QueryChunkData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::RecordBatches(_) => f.debug_tuple("RecordBatches").field(&"").finish(), + Self::Parquet(input) => f.debug_tuple("Parquet").field(input).finish(), + } + } +} + +// Note: I would like to compile this module only in the 'test' cfg, +// but when I do so then other modules can not find them. For example: +// +// error[E0433]: failed to resolve: could not find `test` in `storage` +// --> src/server/mutable_buffer_routes.rs:353:19 +// | +// 353 | use iox_query::test::TestDatabaseStore; +// | ^^^^ could not find `test` in `query` + +// +//#[cfg(test)] +pub mod test; diff --git a/iox_query/src/logical_optimizer/extract_sleep.rs b/iox_query/src/logical_optimizer/extract_sleep.rs new file mode 100644 index 0000000..2f11446 --- /dev/null +++ b/iox_query/src/logical_optimizer/extract_sleep.rs @@ -0,0 +1,100 @@ +use std::sync::Arc; + +use datafusion::logical_expr::expr::ScalarFunction; +use datafusion::{ + common::{tree_node::TreeNodeRewriter, DFSchema}, + error::DataFusionError, + logical_expr::{expr_rewriter::rewrite_preserving_name, Extension, LogicalPlan}, + optimizer::{OptimizerConfig, OptimizerRule}, + prelude::{lit, Expr}, + scalar::ScalarValue, +}; +use query_functions::SLEEP_UDF_NAME; + +use crate::exec::sleep::SleepNode; + +/// Rewrites the ["sleep" UDF](SLEEP_UDF_NAME) to a NULL expression and a [`SleepNode`]. +/// +/// See [`crate::exec::sleep`] for more details. +#[derive(Debug, Clone)] +pub struct ExtractSleep {} + +impl ExtractSleep { + /// Create new optimizer rule. 
+ pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for ExtractSleep { + fn name(&self) -> &str { + "extract_sleep" + } + + fn try_optimize( + &self, + plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> datafusion::error::Result> { + optimize(plan).map(Some) + } +} + +fn optimize(plan: &LogicalPlan) -> Result { + let new_inputs = plan + .inputs() + .iter() + .map(|input| optimize(input)) + .collect::, DataFusionError>>()?; + + let mut schema = + new_inputs + .iter() + .map(|input| input.schema()) + .fold(DFSchema::empty(), |mut lhs, rhs| { + lhs.merge(rhs); + lhs + }); + + schema.merge(plan.schema()); + + let mut expr_rewriter = Rewriter::default(); + + let new_exprs = plan + .expressions() + .into_iter() + .map(|expr| rewrite_preserving_name(expr, &mut expr_rewriter)) + .collect::, DataFusionError>>()?; + let mut plan = plan.with_new_exprs(new_exprs, &new_inputs)?; + + if !expr_rewriter.found_exprs.is_empty() { + plan = LogicalPlan::Extension(Extension { + node: Arc::new(SleepNode::new(plan, expr_rewriter.found_exprs)), + }); + } + + Ok(plan) +} + +#[derive(Default)] +struct Rewriter { + found_exprs: Vec, +} + +impl TreeNodeRewriter for Rewriter { + type N = Expr; + + fn mutate(&mut self, expr: Expr) -> Result { + match expr { + Expr::ScalarFunction(ScalarFunction { func_def, mut args }) => { + if func_def.name() == SLEEP_UDF_NAME { + self.found_exprs.append(&mut args); + return Ok(lit(ScalarValue::Null)); + } + + Ok(Expr::ScalarFunction(ScalarFunction { func_def, args })) + } + _ => Ok(expr), + } + } +} diff --git a/iox_query/src/logical_optimizer/handle_gapfill.rs b/iox_query/src/logical_optimizer/handle_gapfill.rs new file mode 100644 index 0000000..bd046b1 --- /dev/null +++ b/iox_query/src/logical_optimizer/handle_gapfill.rs @@ -0,0 +1,1176 @@ +//! An optimizer rule that transforms a plan +//! to fill gaps in time series data. 
+ +pub mod range_predicate; + +use crate::exec::gapfill::{FillStrategy, GapFill, GapFillParams}; +use datafusion::logical_expr::ScalarFunctionDefinition; +use datafusion::{ + common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter, VisitRecursion}, + error::{DataFusionError, Result}, + logical_expr::{ + expr::{Alias, ScalarFunction}, + utils::expr_to_columns, + Aggregate, BuiltinScalarFunction, Extension, LogicalPlan, Projection, + }, + optimizer::{optimizer::ApplyOrder, OptimizerConfig, OptimizerRule}, + prelude::{col, Column, Expr}, +}; +use hashbrown::{hash_map, HashMap}; +use query_functions::gapfill::{DATE_BIN_GAPFILL_UDF_NAME, INTERPOLATE_UDF_NAME, LOCF_UDF_NAME}; +use std::{ + collections::HashSet, + ops::{Bound, Range}, + sync::Arc, +}; + +/// This optimizer rule enables gap-filling semantics for SQL queries +/// that contain calls to `DATE_BIN_GAPFILL()` and related functions +/// like `LOCF()`. +/// +/// In SQL a typical gap-filling query might look like this: +/// ```sql +/// SELECT +/// location, +/// DATE_BIN_GAPFILL(INTERVAL '1 minute', time, '1970-01-01T00:00:00Z') AS minute, +/// LOCF(AVG(temp)) +/// FROM temps +/// WHERE time > NOW() - INTERVAL '6 hours' AND time < NOW() +/// GROUP BY LOCATION, MINUTE +/// ``` +/// +/// The initial logical plan will look like this: +/// +/// ```text +/// Projection: location, date_bin_gapfill(...) as minute, LOCF(AVG(temps.temp)) +/// Aggregate: groupBy=[[location, date_bin_gapfill(...)]], aggr=[[AVG(temps.temp)]] +/// ... +/// ``` +/// +/// This optimizer rule transforms it to this: +/// +/// ```text +/// Projection: location, date_bin_gapfill(...) as minute, AVG(temps.temp) +/// GapFill: groupBy=[[location, date_bin_gapfill(...))]], aggr=[[LOCF(AVG(temps.temp))]], start=..., stop=... +/// Aggregate: groupBy=[[location, date_bin(...))]], aggr=[[AVG(temps.temp)]] +/// ... 
+/// ``` +/// +/// For `Aggregate` nodes that contain calls to `DATE_BIN_GAPFILL`, this rule will: +/// - Convert `DATE_BIN_GAPFILL()` to `DATE_BIN()` +/// - Create a `GapFill` node that fills in gaps in the query +/// - The range for gap filling is found by analyzing any preceding `Filter` nodes +/// +/// If there is a `Projection` above the `GapFill` node that gets created: +/// - Look for calls to gap-filling functions like `LOCF` +/// - Push down these functions into the `GapFill` node, updating the fill strategy for the column. +/// +/// Note: both `DATE_BIN_GAPFILL` and `LOCF` are functions that don't have implementations. +/// This rule must rewrite the plan to get rid of them. +pub struct HandleGapFill; + +impl HandleGapFill { + pub fn new() -> Self { + Self {} + } +} + +impl Default for HandleGapFill { + fn default() -> Self { + Self::new() + } +} + +impl OptimizerRule for HandleGapFill { + fn try_optimize( + &self, + plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + handle_gap_fill(plan) + } + + fn name(&self) -> &str { + "handle_gap_fill" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } +} + +fn handle_gap_fill(plan: &LogicalPlan) -> Result> { + let res = match plan { + LogicalPlan::Aggregate(aggr) => { + handle_aggregate(aggr).map_err(|e| e.context("handle_aggregate"))? + } + LogicalPlan::Projection(proj) => { + handle_projection(proj).map_err(|e| e.context("handle_projection"))? + } + _ => None, + }; + + if res.is_none() { + // no transformation was applied, + // so make sure the plan is not using gap filling + // functions in an unsupported way. + check_node(plan)?; + } + + Ok(res) +} + +fn handle_aggregate(aggr: &Aggregate) -> Result> { + let Aggregate { + input, + group_expr, + aggr_expr, + schema, + .. + } = aggr; + + // new_group_expr has DATE_BIN_GAPFILL replaced with DATE_BIN. 
+ let RewriteInfo { + new_group_expr, + date_bin_gapfill_index, + date_bin_gapfill_args, + } = if let Some(v) = + replace_date_bin_gapfill(group_expr).map_err(|e| e.context("replace_date_bin_gapfill"))? + { + v + } else { + return Ok(None); + }; + + let new_aggr_plan = { + // Create the aggregate node with the same output schema as the orignal + // one. This means that there will be an output column called `date_bin_gapfill(...)` + // even though the actual expression populating that column will be `date_bin(...)`. + // This seems acceptable since it avoids having to deal with renaming downstream. + let new_aggr_plan = Aggregate::try_new_with_schema( + Arc::clone(input), + new_group_expr, + aggr_expr.clone(), + Arc::clone(schema), + ) + .map_err(|e| e.context("Aggregate::try_new_with_schema"))?; + let new_aggr_plan = LogicalPlan::Aggregate(new_aggr_plan); + check_node(&new_aggr_plan).map_err(|e| e.context("check_node"))?; + new_aggr_plan + }; + + let new_gap_fill_plan = + build_gapfill_node(new_aggr_plan, date_bin_gapfill_index, date_bin_gapfill_args) + .map_err(|e| e.context("build_gapfill_node"))?; + Ok(Some(new_gap_fill_plan)) +} + +fn build_gapfill_node( + new_aggr_plan: LogicalPlan, + date_bin_gapfill_index: usize, + date_bin_gapfill_args: Vec, +) -> Result { + match date_bin_gapfill_args.len() { + 2 | 3 => (), + nargs => { + return Err(DataFusionError::Plan(format!( + "DATE_BIN_GAPFILL expects 2 or 3 arguments, got {nargs}", + ))); + } + } + + let mut args_iter = date_bin_gapfill_args.into_iter(); + + // Ensure that stride argument is a scalar + let stride = args_iter.next().unwrap(); + validate_scalar_expr("stride argument to DATE_BIN_GAPFILL", &stride) + .map_err(|e| e.context("validate_scalar_expr"))?; + + fn get_column(expr: Expr) -> Result { + match expr { + Expr::Column(c) => Ok(c), + Expr::Cast(c) => get_column(*c.expr), + _ => Err(DataFusionError::Plan( + "DATE_BIN_GAPFILL requires a column as the source argument".to_string(), + )), + } + } + + // 
Ensure that the source argument is a column + let time_col = + get_column(args_iter.next().unwrap()).map_err(|e| e.context("get time column"))?; + + // Ensure that a time range was specified and is valid for gap filling + let time_range = range_predicate::find_time_range(new_aggr_plan.inputs()[0], &time_col) + .map_err(|e| e.context("find time range"))?; + validate_time_range(&time_range).map_err(|e| e.context("validate time range"))?; + + // Ensure that origin argument is a scalar + let origin = args_iter.next(); + if let Some(ref origin) = origin { + validate_scalar_expr("origin argument to DATE_BIN_GAPFILL", origin) + .map_err(|e| e.context("validate origin"))?; + } + + // Make sure the time output to the gapfill node matches what the + // aggregate output was. + let time_column = + col(new_aggr_plan.schema().fields()[date_bin_gapfill_index].qualified_column()); + + let LogicalPlan::Aggregate(aggr) = &new_aggr_plan else { + return Err(DataFusionError::Internal(format!( + "Expected Aggregate plan, got {}", + new_aggr_plan.display() + ))); + }; + let mut new_group_expr: Vec<_> = aggr + .schema + .fields() + .iter() + .map(|f| Expr::Column(f.qualified_column())) + .collect(); + let aggr_expr = new_group_expr.split_off(aggr.group_expr.len()); + + let fill_behavior = aggr_expr + .iter() + .cloned() + .map(|e| (e, FillStrategy::Null)) + .collect(); + + Ok(LogicalPlan::Extension(Extension { + node: Arc::new( + GapFill::try_new( + Arc::new(new_aggr_plan), + new_group_expr, + aggr_expr, + GapFillParams { + stride, + time_column, + origin, + time_range, + fill_strategy: fill_behavior, + }, + ) + .map_err(|e| e.context("GapFill::try_new"))?, + ), + })) +} + +fn validate_time_range(range: &Range>) -> Result<()> { + let Range { ref start, ref end } = range; + let (start, end) = match (start, end) { + (Bound::Unbounded, Bound::Unbounded) => { + return Err(DataFusionError::Plan( + "gap-filling query is missing both upper and lower time bounds".to_string(), + )) + } + 
(Bound::Unbounded, _) => Err(DataFusionError::Plan( + "gap-filling query is missing lower time bound".to_string(), + )), + (_, Bound::Unbounded) => Err(DataFusionError::Plan( + "gap-filling query is missing upper time bound".to_string(), + )), + ( + Bound::Included(start) | Bound::Excluded(start), + Bound::Included(end) | Bound::Excluded(end), + ) => Ok((start, end)), + }?; + validate_scalar_expr("lower time bound", start)?; + validate_scalar_expr("upper time bound", end) +} + +fn validate_scalar_expr(what: &str, e: &Expr) -> Result<()> { + let mut cols = HashSet::new(); + expr_to_columns(e, &mut cols)?; + if !cols.is_empty() { + Err(DataFusionError::Plan(format!( + "{what} for gap fill query must evaluate to a scalar" + ))) + } else { + Ok(()) + } +} + +struct RewriteInfo { + // Group expressions with DATE_BIN_GAPFILL rewritten to DATE_BIN. + new_group_expr: Vec, + // The index of the group expression that contained the call to DATE_BIN_GAPFILL. + date_bin_gapfill_index: usize, + // The arguments to the call to DATE_BIN_GAPFILL. + date_bin_gapfill_args: Vec, +} + +// Iterate over the group expression list. +// If it finds no occurrences of date_bin_gapfill, it will return None. +// If it finds more than one occurrence it will return an error. +// Otherwise it will return a RewriteInfo for the optimizer rule to use. +fn replace_date_bin_gapfill(group_expr: &[Expr]) -> Result> { + let mut date_bin_gapfill_count = 0; + let mut dbg_idx = None; + group_expr + .iter() + .enumerate() + .try_for_each(|(i, e)| -> Result<()> { + let fn_cnt = count_udf(e, DATE_BIN_GAPFILL_UDF_NAME)?; + date_bin_gapfill_count += fn_cnt; + if fn_cnt > 0 { + dbg_idx = Some(i); + } + Ok(()) + })?; + match date_bin_gapfill_count { + 0 => return Ok(None), + 1 => { + // Make sure that the call to DATE_BIN_GAPFILL is root expression + // excluding aliases. 
+ let dbg_idx = dbg_idx.expect("should have found exactly one call"); + if !matches_udf( + unwrap_alias(&group_expr[dbg_idx]), + DATE_BIN_GAPFILL_UDF_NAME, + ) { + return Err(DataFusionError::Plan( + "DATE_BIN_GAPFILL must be a top-level expression in the GROUP BY clause when gap filling. It cannot be part of another expression or cast".to_string(), + )); + } + } + _ => { + return Err(DataFusionError::Plan( + "DATE_BIN_GAPFILL specified more than once".to_string(), + )) + } + } + + let date_bin_gapfill_index = dbg_idx.expect("should be found exactly one call"); + + let mut rewriter = DateBinGapfillRewriter { args: None }; + let group_expr = group_expr + .iter() + .enumerate() + .map(|(i, e)| { + if i == date_bin_gapfill_index { + e.clone().rewrite(&mut rewriter) + } else { + Ok(e.clone()) + } + }) + .collect::>>()?; + let date_bin_gapfill_args = rewriter.args.expect("should have found args"); + + Ok(Some(RewriteInfo { + new_group_expr: group_expr, + date_bin_gapfill_index, + date_bin_gapfill_args, + })) +} + +fn unwrap_alias(mut e: &Expr) -> &Expr { + loop { + match e { + Expr::Alias(Alias { expr, .. }) => e = expr.as_ref(), + e => break e, + } + } +} + +struct DateBinGapfillRewriter { + args: Option>, +} + +impl TreeNodeRewriter for DateBinGapfillRewriter { + type N = Expr; + fn pre_visit(&mut self, expr: &Expr) -> Result { + match expr { + Expr::ScalarFunction(fun) if fun.func_def.name() == DATE_BIN_GAPFILL_UDF_NAME => { + Ok(RewriteRecursion::Mutate) + } + _ => Ok(RewriteRecursion::Continue), + } + } + + fn mutate(&mut self, expr: Expr) -> Result { + // We need to preserve the name of the original expression + // so that everything stays wired up. 
+ let orig_name = expr.display_name()?; + match expr { + Expr::ScalarFunction(ScalarFunction { func_def, args }) + if func_def.name() == DATE_BIN_GAPFILL_UDF_NAME => + { + self.args = Some(args.clone()); + Ok(Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::DateBin), + args, + }) + .alias(orig_name)) + } + _ => Ok(expr), + } + } +} + +fn udf_to_fill_strategy(name: &str) -> Option { + match name { + LOCF_UDF_NAME => Some(FillStrategy::PrevNullAsMissing), + INTERPOLATE_UDF_NAME => Some(FillStrategy::LinearInterpolate), + _ => None, + } +} + +fn fill_strategy_to_udf(fs: &FillStrategy) -> Result<&'static str> { + match fs { + FillStrategy::PrevNullAsMissing => Ok(LOCF_UDF_NAME), + FillStrategy::LinearInterpolate => Ok(INTERPOLATE_UDF_NAME), + _ => Err(DataFusionError::Internal(format!( + "unknown UDF for fill strategy {fs:?}" + ))), + } +} + +fn handle_projection(proj: &Projection) -> Result> { + let Projection { + input, + expr: proj_exprs, + schema: proj_schema, + .. + } = proj; + let Some(child_gapfill) = (match input.as_ref() { + LogicalPlan::Extension(Extension { node }) => node.as_any().downcast_ref::(), + _ => None, + }) else { + // If this is not a projection that is a parent to a GapFill node, + // then there is nothing to do. + return Ok(None); + }; + + let mut fill_fn_rewriter = FillFnRewriter { + aggr_col_fill_map: HashMap::new(), + }; + let new_proj_exprs = proj_exprs + .iter() + .map(|expr| { + expr.clone() + .rewrite(&mut fill_fn_rewriter) + .map_err(|e| e.context(format!("rewrite: {expr}"))) + }) + .collect::>>()?; + + let FillFnRewriter { aggr_col_fill_map } = fill_fn_rewriter; + if aggr_col_fill_map.is_empty() { + return Ok(None); + } + + // Clone the existing GapFill node, then modify it in place + // to reflect the new fill strategy. 
+ let mut new_gapfill = child_gapfill.clone(); + for (e, fs) in aggr_col_fill_map { + let udf = fill_strategy_to_udf(&fs).map_err(|e| e.context("fill_strategy_to_udf"))?; + if new_gapfill.replace_fill_strategy(&e, fs).is_none() { + // There was a gap filling function called on a non-aggregate column. + return Err(DataFusionError::Plan(format!( + "{udf} must be called on an aggregate column in a gap-filling query", + ))); + } + } + + let new_proj = { + let mut proj = proj.clone(); + proj.expr = new_proj_exprs; + proj.input = Arc::new(LogicalPlan::Extension(Extension { + node: Arc::new(new_gapfill), + })); + proj.schema = Arc::clone(proj_schema); + LogicalPlan::Projection(proj) + }; + + Ok(Some(new_proj)) +} + +/// Implements `TreeNodeRewriter`: +/// - Traverses over the expressions in a projection node +/// - If it finds `locf(col)` or `interpolate(col)`, +/// it replaces them with `col AS ` +/// - Collects into [`Self::aggr_col_fill_map`] which correlates +/// aggregate columns to their [`FillStrategy`]. 
+struct FillFnRewriter { + aggr_col_fill_map: HashMap, +} + +impl TreeNodeRewriter for FillFnRewriter { + type N = Expr; + fn pre_visit(&mut self, expr: &Expr) -> Result { + match expr { + Expr::ScalarFunction(fun) if udf_to_fill_strategy(fun.func_def.name()).is_some() => { + Ok(RewriteRecursion::Mutate) + } + _ => Ok(RewriteRecursion::Continue), + } + } + + fn mutate(&mut self, expr: Expr) -> Result { + let orig_name = expr.display_name()?; + match expr { + Expr::ScalarFunction(ref fun) + if udf_to_fill_strategy(fun.func_def.name()).is_none() => + { + Ok(expr) + } + Expr::ScalarFunction(mut fun) => { + let fs = udf_to_fill_strategy(fun.func_def.name()).expect("must be a fill fn"); + let arg = fun.args.remove(0); + self.add_fill_strategy(arg.clone(), fs)?; + Ok(arg.alias(orig_name)) + } + _ => Ok(expr), + } + } +} + +impl FillFnRewriter { + fn add_fill_strategy(&mut self, e: Expr, fs: FillStrategy) -> Result<()> { + match self.aggr_col_fill_map.entry(e) { + hash_map::Entry::Occupied(_) => Err(DataFusionError::NotImplemented( + "multiple fill strategies for the same column".to_string(), + )), + hash_map::Entry::Vacant(ve) => { + ve.insert(fs); + Ok(()) + } + } + } +} + +fn count_udf(e: &Expr, name: &str) -> Result { + let mut count = 0; + e.apply(&mut |expr| { + if matches_udf(expr, name) { + count += 1; + } + Ok(VisitRecursion::Continue) + })?; + Ok(count) +} + +fn matches_udf(e: &Expr, name: &str) -> bool { + matches!( + e, + Expr::ScalarFunction(fun) if fun.func_def.name() == name + ) +} + +fn check_node(node: &LogicalPlan) -> Result<()> { + node.expressions().iter().try_for_each(|expr| { + let dbg_count = count_udf(expr, DATE_BIN_GAPFILL_UDF_NAME)?; + if dbg_count > 0 { + return Err(DataFusionError::Plan(format!( + "{DATE_BIN_GAPFILL_UDF_NAME} may only be used as a GROUP BY expression" + ))); + } + + for fn_name in [LOCF_UDF_NAME, INTERPOLATE_UDF_NAME] { + if count_udf(expr, fn_name)? 
> 0 { + return Err(DataFusionError::Plan(format!( + "{fn_name} may only be used in the SELECT list of a gap-filling query" + ))); + } + } + Ok(()) + }) +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use super::HandleGapFill; + + use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; + use datafusion::error::Result; + use datafusion::logical_expr::builder::table_scan_with_filters; + use datafusion::logical_expr::{logical_plan, LogicalPlan, LogicalPlanBuilder}; + use datafusion::optimizer::optimizer::Optimizer; + use datafusion::optimizer::OptimizerContext; + use datafusion::prelude::{avg, case, col, lit, min, Expr}; + use datafusion::scalar::ScalarValue; + use datafusion_util::lit_timestamptz_nano; + use query_functions::gapfill::{ + DATE_BIN_GAPFILL_UDF_NAME, INTERPOLATE_UDF_NAME, LOCF_UDF_NAME, + }; + + fn schema() -> Schema { + Schema::new(vec![ + Field::new( + "time", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new( + "time2", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new("loc", DataType::Utf8, false), + Field::new("temp", DataType::Float64, false), + ]) + } + + fn table_scan() -> Result { + logical_plan::table_scan(Some("temps"), &schema(), None)?.build() + } + + fn date_bin_gapfill(interval: Expr, time: Expr) -> Result { + date_bin_gapfill_with_origin(interval, time, None) + } + + fn date_bin_gapfill_with_origin( + interval: Expr, + time: Expr, + origin: Option, + ) -> Result { + let mut args = vec![interval, time]; + if let Some(origin) = origin { + args.push(origin) + } + + Ok(query_functions::registry() + .udf(DATE_BIN_GAPFILL_UDF_NAME)? + .call(args)) + } + + fn locf(arg: Expr) -> Result { + Ok(query_functions::registry() + .udf(LOCF_UDF_NAME)? + .call(vec![arg])) + } + + fn interpolate(arg: Expr) -> Result { + Ok(query_functions::registry() + .udf(INTERPOLATE_UDF_NAME)? 
+ .call(vec![arg])) + } + + fn optimize(plan: &LogicalPlan) -> Result> { + let optimizer = Optimizer::with_rules(vec![Arc::new(HandleGapFill)]); + optimizer.optimize_recursively(&optimizer.rules[0], plan, &OptimizerContext::new()) + } + + fn assert_optimizer_err(plan: &LogicalPlan, expected: &str) { + match optimize(plan) { + Ok(plan) => assert_eq!(format!("{}", plan.unwrap().display_indent()), "an error"), + Err(ref e) => { + let actual = e.to_string(); + if expected.is_empty() || !actual.contains(expected) { + assert_eq!(actual, expected) + } + } + } + } + + fn assert_optimization_skipped(plan: &LogicalPlan) -> Result<()> { + let new_plan = optimize(plan)?; + if new_plan.is_none() { + return Ok(()); + } + assert_eq!( + format!("{}", plan.display_indent()), + format!("{}", new_plan.unwrap().display_indent()) + ); + Ok(()) + } + + fn format_optimized_plan(plan: &LogicalPlan) -> Result> { + let plan = optimize(plan)? + .expect("plan should have been optimized") + .display_indent() + .to_string(); + Ok(plan.split('\n').map(|s| s.to_string()).collect()) + } + + #[test] + fn misplaced_dbg_err() -> Result<()> { + // date_bin_gapfill used in a filter should produce an error + let scan = table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .filter( + date_bin_gapfill( + lit(ScalarValue::IntervalDayTime(Some(600_000))), + col("temp"), + )? + .gt(lit(100.0)), + )? + .build()?; + assert_optimizer_err( + &plan, + "Error during planning: date_bin_gapfill may only be used as a GROUP BY expression", + ); + Ok(()) + } + + /// calling LOCF in a WHERE predicate is not valid + #[test] + fn misplaced_locf_err() -> Result<()> { + // date_bin_gapfill used in a filter should produce an error + let scan = table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .filter(locf(col("temp"))?.gt(lit(100.0)))? 
+ .build()?; + assert_optimizer_err( + &plan, + "Error during planning: locf may only be used in the SELECT list of a gap-filling query", + ); + Ok(()) + } + + /// calling INTERPOLATE in a WHERE predicate is not valid + #[test] + fn misplaced_interpolate_err() -> Result<()> { + // date_bin_gapfill used in a filter should produce an error + let scan = table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .filter(interpolate(col("temp"))?.gt(lit(100.0)))? + .build()?; + assert_optimizer_err( + &plan, + "Error during planning: interpolate may only be used in the SELECT list of a gap-filling query", + ); + Ok(()) + } + /// calling LOCF on the SELECT list but not on an aggregate column is not valid. + #[test] + fn misplaced_locf_non_agg_err() -> Result<()> { + let plan = LogicalPlanBuilder::from(table_scan()?) + .filter( + col("time") + .gt_eq(lit_timestamptz_nano(1000)) + .and(col("time").lt(lit_timestamptz_nano(2000))), + )? + .aggregate( + vec![ + col("loc"), + date_bin_gapfill(lit(ScalarValue::IntervalDayTime(Some(60_000))), col("time"))?, + ], + vec![avg(col("temp")), min(col("temp"))], + )? + .project(vec![ + locf(col("loc"))?, + col("date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)"), + locf(col("AVG(temps.temp)"))?, + locf(col("MIN(temps.temp)"))?, + ])? + .build()?; + assert_optimizer_err( + &plan, + "locf must be called on an aggregate column in a gap-filling query", + ); + Ok(()) + } + + #[test] + fn different_fill_strategies_one_col() -> Result<()> { + let plan = LogicalPlanBuilder::from(table_scan()?) + .filter( + col("time") + .gt_eq(lit_timestamptz_nano(1000)) + .and(col("time").lt(lit_timestamptz_nano(2000))), + )? + .aggregate( + vec![ + col("loc"), + date_bin_gapfill(lit(ScalarValue::IntervalDayTime(Some(60_000))), col("time"))?, + ], + vec![avg(col("temp")), min(col("temp"))], + )? 
+ .project(vec![ + locf(col("loc"))?, + col("date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)"), + locf(col("AVG(temps.temp)"))?, + interpolate(col("AVG(temps.temp)"))?, + ])? + .build()?; + assert_optimizer_err( + &plan, + "This feature is not implemented: multiple fill strategies for the same column", + ); + Ok(()) + } + + #[test] + fn nonscalar_origin() -> Result<()> { + let plan = LogicalPlanBuilder::from(table_scan()?) + .filter( + col("time") + .gt_eq(lit_timestamptz_nano(1000)) + .and(col("time").lt(lit_timestamptz_nano(2000))), + )? + .aggregate( + vec![date_bin_gapfill_with_origin( + lit(ScalarValue::IntervalDayTime(Some(60_000))), + col("time"), + Some(col("time2")), + )?], + vec![avg(col("temp"))], + )? + .build()?; + assert_optimizer_err( + &plan, + "Error during planning: origin argument to DATE_BIN_GAPFILL for gap fill query must evaluate to a scalar", + ); + Ok(()) + } + + #[test] + fn nonscalar_stride() -> Result<()> { + let stride = case(col("loc")) + .when( + lit("kitchen"), + lit(ScalarValue::IntervalDayTime(Some(60_000))), + ) + .otherwise(lit(ScalarValue::IntervalDayTime(Some(30_000)))) + .unwrap(); + + let plan = LogicalPlanBuilder::from(table_scan()?) + .filter( + col("time") + .gt_eq(lit_timestamptz_nano(1000)) + .and(col("time").lt(lit_timestamptz_nano(2000))), + )? + .aggregate( + vec![date_bin_gapfill(stride, col("time"))?], + vec![avg(col("temp"))], + )? 
+ .build()?; + assert_optimizer_err( + &plan, + "Error during planning: stride argument to DATE_BIN_GAPFILL for gap fill query must evaluate to a scalar", + ); + Ok(()) + } + + #[test] + fn time_range_errs() -> Result<()> { + let cases = vec![ + ( + lit(true), + "Error during planning: gap-filling query is missing both upper and lower time bounds", + ), + ( + col("time").gt_eq(lit_timestamptz_nano(1000)), + "Error during planning: gap-filling query is missing upper time bound", + ), + ( + col("time").lt(lit_timestamptz_nano(2000)), + "Error during planning: gap-filling query is missing lower time bound", + ), + ( + col("time").gt_eq(col("time2")).and( + col("time").lt(lit_timestamptz_nano(2000))), + "Error during planning: lower time bound for gap fill query must evaluate to a scalar", + ), + ( + col("time").gt_eq(lit_timestamptz_nano(2000)).and( + col("time").lt(col("time2"))), + "Error during planning: upper time bound for gap fill query must evaluate to a scalar", + ) + ]; + for c in cases { + let plan = LogicalPlanBuilder::from(table_scan()?) + .filter(c.0)? + .aggregate( + vec![date_bin_gapfill( + lit(ScalarValue::IntervalDayTime(Some(60_000))), + col("time"), + )?], + vec![avg(col("temp"))], + )? + .build()?; + assert_optimizer_err(&plan, c.1); + } + Ok(()) + } + + #[test] + fn no_change() -> Result<()> { + let plan = LogicalPlanBuilder::from(table_scan()?) + .aggregate(vec![col("loc")], vec![avg(col("temp"))])? + .build()?; + assert_optimization_skipped(&plan)?; + Ok(()) + } + + #[test] + fn date_bin_gapfill_simple() -> Result<()> { + let plan = LogicalPlanBuilder::from(table_scan()?) + .filter( + col("time") + .gt_eq(lit_timestamptz_nano(1000)) + .and(col("time").lt(lit_timestamptz_nano(2000))), + )? + .aggregate( + vec![date_bin_gapfill( + lit(ScalarValue::IntervalDayTime(Some(60_000))), + col("time"), + )?], + vec![avg(col("temp"))], + )? 
+ .build()?; + + insta::assert_yaml_snapshot!( + format_optimized_plan(&plan)?, + @r###" + --- + - "GapFill: groupBy=[date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)], aggr=[[AVG(temps.temp)]], time_column=date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time), stride=IntervalDayTime(\"60000\"), range=Included(Literal(TimestampNanosecond(1000, None)))..Excluded(Literal(TimestampNanosecond(2000, None)))" + - " Aggregate: groupBy=[[date_bin(IntervalDayTime(\"60000\"), temps.time) AS date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)]], aggr=[[AVG(temps.temp)]]" + - " Filter: temps.time >= TimestampNanosecond(1000, None) AND temps.time < TimestampNanosecond(2000, None)" + - " TableScan: temps" + "###); + Ok(()) + } + + #[test] + fn date_bin_gapfill_origin() -> Result<()> { + let plan = LogicalPlanBuilder::from(table_scan()?) + .filter( + col("time") + .gt_eq(lit_timestamptz_nano(1000)) + .and(col("time").lt(lit_timestamptz_nano(2000))), + )? + .aggregate( + vec![date_bin_gapfill_with_origin( + lit(ScalarValue::IntervalDayTime(Some(60_000))), + col("time"), + Some(lit_timestamptz_nano(7)), + )?], + vec![avg(col("temp"))], + )? 
+ .build()?; + + insta::assert_yaml_snapshot!( + format_optimized_plan(&plan)?, + @r###" + --- + - "GapFill: groupBy=[date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time,TimestampNanosecond(7, None))], aggr=[[AVG(temps.temp)]], time_column=date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time,TimestampNanosecond(7, None)), stride=IntervalDayTime(\"60000\"), range=Included(Literal(TimestampNanosecond(1000, None)))..Excluded(Literal(TimestampNanosecond(2000, None)))" + - " Aggregate: groupBy=[[date_bin(IntervalDayTime(\"60000\"), temps.time, TimestampNanosecond(7, None)) AS date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time,TimestampNanosecond(7, None))]], aggr=[[AVG(temps.temp)]]" + - " Filter: temps.time >= TimestampNanosecond(1000, None) AND temps.time < TimestampNanosecond(2000, None)" + - " TableScan: temps" + "###); + Ok(()) + } + #[test] + fn two_group_exprs() -> Result<()> { + // grouping by date_bin_gapfill(...), loc + let plan = LogicalPlanBuilder::from(table_scan()?) + .filter( + col("time") + .gt_eq(lit_timestamptz_nano(1000)) + .and(col("time").lt(lit_timestamptz_nano(2000))), + )? + .aggregate( + vec![ + date_bin_gapfill(lit(ScalarValue::IntervalDayTime(Some(60_000))), col("time"))?, + col("loc"), + ], + vec![avg(col("temp"))], + )? 
+ .build()?; + + insta::assert_yaml_snapshot!( + format_optimized_plan(&plan)?, + @r###" + --- + - "GapFill: groupBy=[date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time), temps.loc], aggr=[[AVG(temps.temp)]], time_column=date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time), stride=IntervalDayTime(\"60000\"), range=Included(Literal(TimestampNanosecond(1000, None)))..Excluded(Literal(TimestampNanosecond(2000, None)))" + - " Aggregate: groupBy=[[date_bin(IntervalDayTime(\"60000\"), temps.time) AS date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time), temps.loc]], aggr=[[AVG(temps.temp)]]" + - " Filter: temps.time >= TimestampNanosecond(1000, None) AND temps.time < TimestampNanosecond(2000, None)" + - " TableScan: temps" + "###); + Ok(()) + } + + #[test] + fn double_date_bin_gapfill() -> Result<()> { + let plan = LogicalPlanBuilder::from(table_scan()?) + .aggregate( + vec![ + date_bin_gapfill(lit(ScalarValue::IntervalDayTime(Some(60_000))), col("time"))?, + date_bin_gapfill(lit(ScalarValue::IntervalDayTime(Some(30_000))), col("time"))?, + ], + vec![avg(col("temp"))], + )? + .build()?; + assert_optimizer_err( + &plan, + "Error during planning: DATE_BIN_GAPFILL specified more than once", + ); + Ok(()) + } + + #[test] + fn with_projection() -> Result<()> { + let plan = LogicalPlanBuilder::from(table_scan()?) + .filter( + col("time") + .gt_eq(lit_timestamptz_nano(1000)) + .and(col("time").lt(lit_timestamptz_nano(2000))), + )? + .aggregate( + vec![date_bin_gapfill( + lit(ScalarValue::IntervalDayTime(Some(60_000))), + col("time"), + )?], + vec![avg(col("temp"))], + )? + .project(vec![ + col("date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)"), + col("AVG(temps.temp)"), + ])? 
+ .build()?; + + insta::assert_yaml_snapshot!( + format_optimized_plan(&plan)?, + @r###" + --- + - "Projection: date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time), AVG(temps.temp)" + - " GapFill: groupBy=[date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)], aggr=[[AVG(temps.temp)]], time_column=date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time), stride=IntervalDayTime(\"60000\"), range=Included(Literal(TimestampNanosecond(1000, None)))..Excluded(Literal(TimestampNanosecond(2000, None)))" + - " Aggregate: groupBy=[[date_bin(IntervalDayTime(\"60000\"), temps.time) AS date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)]], aggr=[[AVG(temps.temp)]]" + - " Filter: temps.time >= TimestampNanosecond(1000, None) AND temps.time < TimestampNanosecond(2000, None)" + - " TableScan: temps" + "###); + Ok(()) + } + + #[test] + fn with_locf() -> Result<()> { + let plan = LogicalPlanBuilder::from(table_scan()?) + .filter( + col("time") + .gt_eq(lit_timestamptz_nano(1000)) + .and(col("time").lt(lit_timestamptz_nano(2000))), + )? + .aggregate( + vec![date_bin_gapfill( + lit(ScalarValue::IntervalDayTime(Some(60_000))), + col("time"), + )?], + vec![avg(col("temp")), min(col("temp"))], + )? + .project(vec![ + col("date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)"), + locf(col("AVG(temps.temp)"))?, + locf(col("MIN(temps.temp)"))?, + ])? 
+ .build()?; + + insta::assert_yaml_snapshot!( + format_optimized_plan(&plan)?, + @r###" + --- + - "Projection: date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time), AVG(temps.temp) AS locf(AVG(temps.temp)), MIN(temps.temp) AS locf(MIN(temps.temp))" + - " GapFill: groupBy=[date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)], aggr=[[LOCF(AVG(temps.temp)), LOCF(MIN(temps.temp))]], time_column=date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time), stride=IntervalDayTime(\"60000\"), range=Included(Literal(TimestampNanosecond(1000, None)))..Excluded(Literal(TimestampNanosecond(2000, None)))" + - " Aggregate: groupBy=[[date_bin(IntervalDayTime(\"60000\"), temps.time) AS date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)]], aggr=[[AVG(temps.temp), MIN(temps.temp)]]" + - " Filter: temps.time >= TimestampNanosecond(1000, None) AND temps.time < TimestampNanosecond(2000, None)" + - " TableScan: temps" + "###); + Ok(()) + } + + #[test] + fn with_locf_aliased() -> Result<()> { + let plan = LogicalPlanBuilder::from(table_scan()?) + .filter( + col("time") + .gt_eq(lit_timestamptz_nano(1000)) + .and(col("time").lt(lit_timestamptz_nano(2000))), + )? + .aggregate( + vec![date_bin_gapfill( + lit(ScalarValue::IntervalDayTime(Some(60_000))), + col("time"), + )?], + vec![avg(col("temp")), min(col("temp"))], + )? + .project(vec![ + col("date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)"), + locf(col("MIN(temps.temp)"))?.alias("locf_min_temp"), + ])? 
+ .build()?; + + insta::assert_yaml_snapshot!( + format_optimized_plan(&plan)?, + @r###" + --- + - "Projection: date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time), MIN(temps.temp) AS locf(MIN(temps.temp)) AS locf_min_temp" + - " GapFill: groupBy=[date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)], aggr=[[AVG(temps.temp), LOCF(MIN(temps.temp))]], time_column=date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time), stride=IntervalDayTime(\"60000\"), range=Included(Literal(TimestampNanosecond(1000, None)))..Excluded(Literal(TimestampNanosecond(2000, None)))" + - " Aggregate: groupBy=[[date_bin(IntervalDayTime(\"60000\"), temps.time) AS date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)]], aggr=[[AVG(temps.temp), MIN(temps.temp)]]" + - " Filter: temps.time >= TimestampNanosecond(1000, None) AND temps.time < TimestampNanosecond(2000, None)" + - " TableScan: temps" + "###); + Ok(()) + } + + #[test] + fn with_interpolate() -> Result<()> { + let plan = LogicalPlanBuilder::from(table_scan()?) + .filter( + col("time") + .gt_eq(lit_timestamptz_nano(1000)) + .and(col("time").lt(lit_timestamptz_nano(2000))), + )? + .aggregate( + vec![date_bin_gapfill( + lit(ScalarValue::IntervalDayTime(Some(60_000))), + col("time"), + )?], + vec![avg(col("temp")), min(col("temp"))], + )? + .project(vec![ + col("date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)"), + interpolate(col("AVG(temps.temp)"))?, + interpolate(col("MIN(temps.temp)"))?, + ])? 
+ .build()?; + + insta::assert_yaml_snapshot!( + format_optimized_plan(&plan)?, + @r###" + --- + - "Projection: date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time), AVG(temps.temp) AS interpolate(AVG(temps.temp)), MIN(temps.temp) AS interpolate(MIN(temps.temp))" + - " GapFill: groupBy=[date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)], aggr=[[INTERPOLATE(AVG(temps.temp)), INTERPOLATE(MIN(temps.temp))]], time_column=date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time), stride=IntervalDayTime(\"60000\"), range=Included(Literal(TimestampNanosecond(1000, None)))..Excluded(Literal(TimestampNanosecond(2000, None)))" + - " Aggregate: groupBy=[[date_bin(IntervalDayTime(\"60000\"), temps.time) AS date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)]], aggr=[[AVG(temps.temp), MIN(temps.temp)]]" + - " Filter: temps.time >= TimestampNanosecond(1000, None) AND temps.time < TimestampNanosecond(2000, None)" + - " TableScan: temps" + "###); + Ok(()) + } + + #[test] + fn scan_filter_not_part_of_projection() { + let schema = schema(); + let plan = table_scan_with_filters( + Some("temps"), + &schema, + Some(vec![schema.index_of("time").unwrap()]), + vec![ + col("temps.time").gt_eq(lit_timestamptz_nano(1000)), + col("temps.time").lt(lit_timestamptz_nano(2000)), + col("temps.loc").eq(lit("foo")), + ], + ) + .unwrap() + .aggregate( + vec![ + date_bin_gapfill(lit(ScalarValue::IntervalDayTime(Some(60_000))), col("time")) + .unwrap(), + ], + std::iter::empty::(), + ) + .unwrap() + .build() + .unwrap(); + + insta::assert_yaml_snapshot!( + format_optimized_plan(&plan).unwrap(), + @r###" + --- + - "GapFill: groupBy=[date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)], aggr=[[]], time_column=date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time), stride=IntervalDayTime(\"60000\"), range=Included(Literal(TimestampNanosecond(1000, None)))..Excluded(Literal(TimestampNanosecond(2000, None)))" + - " Aggregate: groupBy=[[date_bin(IntervalDayTime(\"60000\"), temps.time) AS 
date_bin_gapfill(IntervalDayTime(\"60000\"),temps.time)]], aggr=[[]]" + - " TableScan: temps projection=[time], full_filters=[temps.time >= TimestampNanosecond(1000, None), temps.time < TimestampNanosecond(2000, None), temps.loc = Utf8(\"foo\")]" + "###); + } +} diff --git a/iox_query/src/logical_optimizer/handle_gapfill/range_predicate.rs b/iox_query/src/logical_optimizer/handle_gapfill/range_predicate.rs new file mode 100644 index 0000000..26b9682 --- /dev/null +++ b/iox_query/src/logical_optimizer/handle_gapfill/range_predicate.rs @@ -0,0 +1,367 @@ +//! Find the time range from the filters in a logical plan. +use std::{ + ops::{Bound, Range}, + sync::Arc, +}; + +use datafusion::{ + common::{ + tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}, + DFSchema, + }, + error::Result, + logical_expr::{ + utils::split_conjunction, Between, BinaryExpr, LogicalPlan, LogicalPlanBuilder, Operator, + }, + prelude::{Column, Expr}, +}; + +use super::unwrap_alias; + +/// Given a plan and a column, finds the predicates that use that column +/// and return a range with expressions for upper and lower bounds. 
+pub fn find_time_range(plan: &LogicalPlan, time_col: &Column) -> Result<Range<Bound<Expr>>> {
+    let mut v = TimeRangeVisitor {
+        col: time_col.clone(),
+        range: TimeRange::default(),
+    };
+    plan.visit(&mut v)?;
+    Ok(v.range.0)
+}
+
+struct TimeRangeVisitor {
+    col: Column,
+    range: TimeRange,
+}
+
+impl TreeNodeVisitor for TimeRangeVisitor {
+    type N = LogicalPlan;
+
+    fn pre_visit(&mut self, plan: &LogicalPlan) -> Result<VisitRecursion> {
+        match plan {
+            LogicalPlan::Projection(p) => {
+                let idx = p.schema.index_of_column(&self.col)?;
+                match unwrap_alias(&p.expr[idx]) {
+                    Expr::Column(ref c) => {
+                        self.col = c.clone();
+                        Ok(VisitRecursion::Continue)
+                    }
+                    _ => Ok(VisitRecursion::Stop),
+                }
+            }
+            LogicalPlan::Filter(f) => {
+                let range = self.range.clone();
+                let range = split_conjunction(&f.predicate)
+                    .iter()
+                    .try_fold(range, |range, expr| {
+                        range.with_expr(f.input.schema().as_ref(), &self.col, expr)
+                    })?;
+                self.range = range;
+                Ok(VisitRecursion::Continue)
+            }
+            LogicalPlan::TableScan(t) => {
+                let range = self.range.clone();
+
+                // filters may use columns that are NOT part of a projection, so we need the underlying schema. Because
+                // that's a bit of a mess in DF, we reconstruct the schema using the plan builder.
+                let unprojected_scan = LogicalPlanBuilder::scan_with_filters(
+                    t.table_name.to_owned(),
+                    Arc::clone(&t.source),
+                    None,
+                    t.filters.clone(),
+                )
+                .map_err(|e| e.context("reconstruct unprojected scheam"))?;
+                let unprojected_schema = unprojected_scan.schema();
+                let range = t
+                    .filters
+                    .iter()
+                    .flat_map(split_conjunction)
+                    .try_fold(range, |range, expr| {
+                        range.with_expr(unprojected_schema, &self.col, expr)
+                    })?;
+                self.range = range;
+                Ok(VisitRecursion::Continue)
+            }
+            LogicalPlan::SubqueryAlias(_) => {
+                // The nodes below this one refer to the column with a different table name,
+                // just unset the relation so we match on the column name.
+                self.col.relation = None;
+                Ok(VisitRecursion::Continue)
+            }
+            // These nodes do not alter their schema, so we can recurse through them
+            LogicalPlan::Sort(_)
+            | LogicalPlan::Repartition(_)
+            | LogicalPlan::Limit(_)
+            | LogicalPlan::Distinct(_) => Ok(VisitRecursion::Continue),
+            // At some point we may wish to handle joins here too.
+            _ => Ok(VisitRecursion::Stop),
+        }
+    }
+}
+
+/// Encapsulates the upper and lower bounds of a time column
+/// in a logical plan.
+#[derive(Clone)]
+struct TimeRange(pub Range<Bound<Expr>>);
+
+impl Default for TimeRange {
+    fn default() -> Self {
+        Self(Range {
+            start: Bound::Unbounded,
+            end: Bound::Unbounded,
+        })
+    }
+}
+
+impl TimeRange {
+    // If the given expression uses the given column with comparison operators, update
+    // this time range to reflect that.
+    fn with_expr(self, schema: &DFSchema, time_col: &Column, expr: &Expr) -> Result<Self> {
+        let is_time_col = |e| -> Result<bool> {
+            match Expr::try_into_col(e) {
+                Ok(col) => Ok(schema.index_of_column(&col)? == schema.index_of_column(time_col)?),
+                Err(_) => Ok(false),
+            }
+        };
+
+        Ok(match expr {
+            Expr::BinaryExpr(BinaryExpr { left, op, right }) if is_time_col(left)? => match op {
+                Operator::Lt => self.with_upper(Bound::Excluded(*right.clone())),
+                Operator::LtEq => self.with_upper(Bound::Included(*right.clone())),
+                Operator::Gt => self.with_lower(Bound::Excluded(*right.clone())),
+                Operator::GtEq => self.with_lower(Bound::Included(*right.clone())),
+                _ => self,
+            },
+            Expr::BinaryExpr(BinaryExpr { left, op, right }) if is_time_col(right)? => match op {
+                Operator::Lt => self.with_lower(Bound::Excluded(*left.clone())),
+                Operator::LtEq => self.with_lower(Bound::Included(*left.clone())),
+                Operator::Gt => self.with_upper(Bound::Excluded(*left.clone())),
+                Operator::GtEq => self.with_upper(Bound::Included(*left.clone())),
+                _ => self,
+            },
+            // Between bounds are inclusive
+            Expr::Between(Between {
+                expr,
+                negated: false,
+                low,
+                high,
+            }) if is_time_col(expr)?
 => self
+                .with_lower(Bound::Included(*low.clone()))
+                .with_upper(Bound::Included(*high.clone())),
+            _ => self,
+        })
+    }
+
+    fn with_lower(self, start: Bound<Expr>) -> Self {
+        Self(Range {
+            start,
+            end: self.0.end,
+        })
+    }
+
+    fn with_upper(self, end: Bound<Expr>) -> Self {
+        Self(Range {
+            start: self.0.start,
+            end,
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::{
+        ops::{Bound, Range},
+        sync::Arc,
+    };
+
+    use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
+    use datafusion::{
+        error::Result,
+        logical_expr::{
+            logical_plan::{self, builder::LogicalTableSource},
+            Between, LogicalPlan, LogicalPlanBuilder,
+        },
+        prelude::{col, lit, Column, Expr, Partitioning},
+        sql::TableReference,
+    };
+    use datafusion_util::lit_timestamptz_nano;
+
+    use super::find_time_range;
+
+    fn schema() -> Schema {
+        Schema::new(vec![
+            Field::new(
+                "time",
+                DataType::Timestamp(TimeUnit::Nanosecond, None),
+                false,
+            ),
+            Field::new("temp", DataType::Float64, false),
+        ])
+    }
+
+    fn table_scan() -> Result<LogicalPlan> {
+        let schema = schema();
+        logical_plan::table_scan(Some("t"), &schema, None)?.build()
+    }
+
+    fn simple_filter_plan(pred: Expr, inline_filter: bool) -> Result<LogicalPlan> {
+        let schema = schema();
+        let table_source = Arc::new(LogicalTableSource::new(Arc::new(schema)));
+        let name = TableReference::from("t").to_quoted_string();
+        if inline_filter {
+            LogicalPlanBuilder::scan_with_filters(name, table_source, None, vec![pred])?.build()
+        } else {
+            LogicalPlanBuilder::scan(name, table_source, None)?
+                .filter(pred)?
+ .build() + } + } + + fn between(expr: Expr, low: Expr, high: Expr) -> Expr { + Expr::Between(Between { + expr: Box::new(expr), + negated: false, + low: Box::new(low), + high: Box::new(high), + }) + } + + #[test] + fn test_find_range() -> Result<()> { + let time_col = Column::from_name("time"); + + let cases = vec![ + ( + "unbounded", + lit(true), + Range { + start: Bound::Unbounded, + end: Bound::Unbounded, + }, + ), + ( + "time_gt_val", + col("time").gt(lit_timestamptz_nano(1000)), + Range { + start: Bound::Excluded(lit_timestamptz_nano(1000)), + end: Bound::Unbounded, + }, + ), + ( + "time_gt_eq_val", + col("time").gt_eq(lit_timestamptz_nano(1000)), + Range { + start: Bound::Included(lit_timestamptz_nano(1000)), + end: Bound::Unbounded, + }, + ), + ( + "time_lt_val", + col("time").lt(lit_timestamptz_nano(1000)), + Range { + start: Bound::Unbounded, + end: Bound::Excluded(lit_timestamptz_nano(1000)), + }, + ), + ( + "time_lt_eq_val", + col("time").lt_eq(lit_timestamptz_nano(1000)), + Range { + start: Bound::Unbounded, + end: Bound::Included(lit_timestamptz_nano(1000)), + }, + ), + ( + "val_gt_time", + lit_timestamptz_nano(1000).gt(col("time")), + Range { + start: Bound::Unbounded, + end: Bound::Excluded(lit_timestamptz_nano(1000)), + }, + ), + ( + "val_gt_eq_time", + lit_timestamptz_nano(1000).gt_eq(col("time")), + Range { + start: Bound::Unbounded, + end: Bound::Included(lit_timestamptz_nano(1000)), + }, + ), + ( + "val_lt_time", + lit_timestamptz_nano(1000).lt(col("time")), + Range { + start: Bound::Excluded(lit_timestamptz_nano(1000)), + end: Bound::Unbounded, + }, + ), + ( + "val_lt_eq_time", + lit_timestamptz_nano(1000).lt_eq(col("time")), + Range { + start: Bound::Included(lit_timestamptz_nano(1000)), + end: Bound::Unbounded, + }, + ), + ( + "and", + col("time") + .gt_eq(lit_timestamptz_nano(1000)) + .and(col("time").lt(lit_timestamptz_nano(2000))), + Range { + start: Bound::Included(lit_timestamptz_nano(1000)), + end: 
Bound::Excluded(lit_timestamptz_nano(2000)), + }, + ), + ( + "between", + between( + col("time"), + lit_timestamptz_nano(1000), + lit_timestamptz_nano(2000), + ), + Range { + start: Bound::Included(lit_timestamptz_nano(1000)), + end: Bound::Included(lit_timestamptz_nano(2000)), + }, + ), + ]; + for (name, pred, expected) in cases { + for inline_filter in [false, true] { + let plan = simple_filter_plan(pred.clone(), inline_filter)?; + let actual = find_time_range(&plan, &time_col)?; + assert_eq!( + expected, actual, + "test case `{name}` with inline_filter={inline_filter} failed", + ); + } + } + Ok(()) + } + + #[test] + fn plan_traversal() -> Result<()> { + // Show that the time range can be found + // - through nodes that don't alter their schema + // - even when predicates are in different filter nodes + // - through projections that alias columns + let plan = LogicalPlanBuilder::from(table_scan()?) + .filter(col("time").gt_eq(lit_timestamptz_nano(1000)))? + .sort(vec![col("time")])? + .limit(0, Some(10))? + .project(vec![col("time").alias("other_time")])? + .filter(col("other_time").lt(lit_timestamptz_nano(2000)))? + .distinct()? + .repartition(Partitioning::RoundRobinBatch(1))? + .project(vec![col("other_time").alias("my_time")])? 
+ .build()?; + let time_col = Column::from_name("my_time"); + let actual = find_time_range(&plan, &time_col)?; + let expected = Range { + start: Bound::Included(lit_timestamptz_nano(1000)), + end: Bound::Excluded(lit_timestamptz_nano(2000)), + }; + assert_eq!(expected, actual); + Ok(()) + } +} diff --git a/iox_query/src/logical_optimizer/influx_regex_to_datafusion_regex.rs b/iox_query/src/logical_optimizer/influx_regex_to_datafusion_regex.rs new file mode 100644 index 0000000..3660cdb --- /dev/null +++ b/iox_query/src/logical_optimizer/influx_regex_to_datafusion_regex.rs @@ -0,0 +1,96 @@ +use datafusion::logical_expr::expr::ScalarFunction; +use datafusion::{ + common::{tree_node::TreeNodeRewriter, DFSchema}, + error::DataFusionError, + logical_expr::{expr_rewriter::rewrite_preserving_name, LogicalPlan, Operator}, + optimizer::{OptimizerConfig, OptimizerRule}, + prelude::{binary_expr, lit, Expr}, + scalar::ScalarValue, +}; +use query_functions::{clean_non_meta_escapes, REGEX_MATCH_UDF_NAME, REGEX_NOT_MATCH_UDF_NAME}; + +/// Replaces InfluxDB-specific regex operator with DataFusion regex operator. +/// +/// InfluxDB has a special regex operator that is especially used by Flux/InfluxQL and that excepts certain escape +/// sequences that are normal Rust regex crate does NOT support. If the pattern is already known at planning time (i.e. +/// it is a constant), then we can clean the escape sequences and just use the ordinary DataFusion regex operator. This +/// is desired because the ordinary DataFusion regex operator can be optimized further (e.g. to cheaper `LIKE` expressions). +#[derive(Debug, Clone)] +pub struct InfluxRegexToDataFusionRegex {} + +impl InfluxRegexToDataFusionRegex { + /// Create new optimizer rule. 
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl OptimizerRule for InfluxRegexToDataFusionRegex {
+    fn name(&self) -> &str {
+        "influx_regex_to_datafusion_regex"
+    }
+
+    fn try_optimize(
+        &self,
+        plan: &LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> datafusion::error::Result<Option<LogicalPlan>> {
+        optimize(plan).map(Some)
+    }
+}
+
+fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan, DataFusionError> {
+    let new_inputs = plan
+        .inputs()
+        .iter()
+        .map(|input| optimize(input))
+        .collect::<Result<Vec<_>, DataFusionError>>()?;
+
+    let mut schema =
+        new_inputs
+            .iter()
+            .map(|input| input.schema())
+            .fold(DFSchema::empty(), |mut lhs, rhs| {
+                lhs.merge(rhs);
+                lhs
+            });
+
+    schema.merge(plan.schema());
+
+    let mut expr_rewriter = InfluxRegexToDataFusionRegex {};
+
+    let new_exprs = plan
+        .expressions()
+        .into_iter()
+        .map(|expr| rewrite_preserving_name(expr, &mut expr_rewriter))
+        .collect::<Result<Vec<_>, DataFusionError>>()?;
+    plan.with_new_exprs(new_exprs, &new_inputs)
+}
+
+impl TreeNodeRewriter for InfluxRegexToDataFusionRegex {
+    type N = Expr;
+
+    fn mutate(&mut self, expr: Expr) -> Result<Expr, DataFusionError> {
+        match expr {
+            Expr::ScalarFunction(ScalarFunction { func_def, mut args }) => {
+                let name = func_def.name();
+                if (args.len() == 2)
+                    && ((name == REGEX_MATCH_UDF_NAME) || (name == REGEX_NOT_MATCH_UDF_NAME))
+                {
+                    if let Expr::Literal(ScalarValue::Utf8(Some(s))) = &args[1] {
+                        let s = clean_non_meta_escapes(s);
+                        let op = match name {
+                            REGEX_MATCH_UDF_NAME => Operator::RegexMatch,
+                            REGEX_NOT_MATCH_UDF_NAME => Operator::RegexNotMatch,
+                            _ => unreachable!(),
+                        };
+                        return Ok(binary_expr(args.remove(0), op, lit(s)));
+                    }
+                }
+
+                Ok(Expr::ScalarFunction(ScalarFunction { func_def, args }))
+            }
+            _ => Ok(expr),
+        }
+    }
+}
diff --git a/iox_query/src/logical_optimizer/mod.rs b/iox_query/src/logical_optimizer/mod.rs
new file mode 100644
index 0000000..42b72e1
--- /dev/null
+++ b/iox_query/src/logical_optimizer/mod.rs
@@ -0,0 +1,23 @@
+use std::sync::Arc;
+
+use datafusion::execution::context::SessionState;
+
+use self::{
+
extract_sleep::ExtractSleep, handle_gapfill::HandleGapFill, + influx_regex_to_datafusion_regex::InfluxRegexToDataFusionRegex, +}; + +mod extract_sleep; +mod handle_gapfill; +mod influx_regex_to_datafusion_regex; +pub use handle_gapfill::range_predicate; + +/// Register IOx-specific logical [`OptimizerRule`]s with the SessionContext +/// +/// [`OptimizerRule`]: datafusion::optimizer::OptimizerRule +pub fn register_iox_logical_optimizers(state: SessionState) -> SessionState { + state + .add_optimizer_rule(Arc::new(InfluxRegexToDataFusionRegex::new())) + .add_optimizer_rule(Arc::new(ExtractSleep::new())) + .add_optimizer_rule(Arc::new(HandleGapFill::new())) +} diff --git a/iox_query/src/physical_optimizer/chunk_extraction.rs b/iox_query/src/physical_optimizer/chunk_extraction.rs new file mode 100644 index 0000000..488b5df --- /dev/null +++ b/iox_query/src/physical_optimizer/chunk_extraction.rs @@ -0,0 +1,367 @@ +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use datafusion::{ + datasource::physical_plan::ParquetExec, + error::DataFusionError, + physical_plan::{ + empty::EmptyExec, placeholder_row::PlaceholderRowExec, union::UnionExec, + visit_execution_plan, ExecutionPlan, ExecutionPlanVisitor, + }, +}; +use observability_deps::tracing::debug; +use schema::sort::SortKey; + +use crate::{ + provider::{PartitionedFileExt, RecordBatchesExec}, + QueryChunk, +}; + +/// List of [`QueryChunk`]s. +pub type QueryChunks = Vec>; + +/// Extract chunks, schema, and output sort key from plans created with [`chunks_to_physical_nodes`]. +/// +/// Returns `None` if no chunks (or an [`EmptyExec`] in case that no chunks where passed to +/// [`chunks_to_physical_nodes`]) were found or if the chunk data is inconsistent. +/// +/// When no chunks were passed to [`chunks_to_physical_nodes`] and hence an [`EmptyExec`] was created, then no output +/// sort key can be reconstructed. However this is usually OK because it does not have any effect anyways. 
+/// +/// Note that this only works on the direct output of [`chunks_to_physical_nodes`]. If the plan is wrapped into +/// additional nodes (like de-duplication, filtering, projection) then NO data will be returned. Also [`ParquetExec`] +/// MUST NOT have a predicate attached. +/// +/// +/// [`chunks_to_physical_nodes`]: crate::provider::chunks_to_physical_nodes +pub fn extract_chunks( + plan: &dyn ExecutionPlan, +) -> Option<(SchemaRef, QueryChunks, Option)> { + let mut visitor = ExtractChunksVisitor::default(); + if let Err(e) = visit_execution_plan(plan, &mut visitor) { + debug!( + %e, + "cannot extract chunks", + ); + return None; + } + visitor + .schema + .map(|schema| (schema, visitor.chunks, visitor.sort_key)) +} + +#[derive(Debug, Default)] +struct ExtractChunksVisitor { + chunks: Vec>, + schema: Option, + sort_key: Option, +} + +impl ExtractChunksVisitor { + fn add_chunk(&mut self, chunk: Arc) { + self.chunks.push(chunk); + } + + fn add_schema_from_exec(&mut self, exec: &dyn ExecutionPlan) -> Result<(), DataFusionError> { + let schema = exec.schema(); + if let Some(existing) = &self.schema { + if existing != &schema { + return Err(DataFusionError::External( + String::from("Different schema").into(), + )); + } + } else { + self.schema = Some(schema); + } + Ok(()) + } + + fn add_sort_key(&mut self, sort_key: Option<&SortKey>) -> Result<(), DataFusionError> { + let Some(sort_key) = sort_key else { + return Ok(()); + }; + + if let Some(existing) = &self.sort_key { + if existing != sort_key { + return Err(DataFusionError::External( + String::from("Different sort key").into(), + )); + } + } else { + self.sort_key = Some(sort_key.clone()); + } + + Ok(()) + } +} + +impl ExecutionPlanVisitor for ExtractChunksVisitor { + type Error = DataFusionError; + + fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result { + let plan_any = plan.as_any(); + + if let Some(record_batches_exec) = plan_any.downcast_ref::() { + self.add_schema_from_exec(record_batches_exec) + 
.map_err(|e| { + DataFusionError::Context( + "add schema from RecordBatchesExec".to_owned(), + Box::new(e), + ) + })?; + + self.add_sort_key(record_batches_exec.output_sort_key_memo())?; + + for chunk in record_batches_exec.chunks() { + self.add_chunk(Arc::clone(chunk)); + } + } else if let Some(parquet_exec) = plan_any.downcast_ref::() { + if parquet_exec.predicate().is_some() { + return Err(DataFusionError::External( + String::from("ParquetExec has predicate").into(), + )); + } + + self.add_schema_from_exec(parquet_exec).map_err(|e| { + DataFusionError::Context("add schema from ParquetExec".to_owned(), Box::new(e)) + })?; + + for group in &parquet_exec.base_config().file_groups { + for file in group { + let ext = file + .extensions + .as_ref() + .and_then(|any| any.downcast_ref::()) + .ok_or_else(|| { + DataFusionError::External( + String::from("PartitionedFileExt not found").into(), + ) + })?; + self.add_sort_key(ext.output_sort_key_memo.as_ref())?; + self.add_chunk(Arc::clone(&ext.chunk)); + } + } + } else if plan_any.downcast_ref::().is_some() { + // should not produce dummy data + return Err(DataFusionError::External( + String::from("EmptyExec produces row").into(), + )); + } else if let Some(empty_exec) = plan_any.downcast_ref::() { + self.add_schema_from_exec(empty_exec).map_err(|e| { + DataFusionError::Context("add schema from EmptyExec".to_owned(), Box::new(e)) + })?; + } else if plan_any.downcast_ref::().is_some() { + // continue visiting + } else { + // unsupported node + return Err(DataFusionError::External( + String::from("Unsupported node").into(), + )); + } + + Ok(true) + } +} + +#[cfg(test)] +mod tests { + use crate::{provider::chunks_to_physical_nodes, test::TestChunk, util::df_physical_expr}; + use arrow::datatypes::{DataType, Field, Schema as ArrowSchema}; + use data_types::ChunkId; + use datafusion::{ + common::tree_node::{Transformed, TreeNode}, + physical_plan::{expressions::Literal, filter::FilterExec}, + prelude::{col, lit}, + 
scalar::ScalarValue, + }; + use schema::{merge::SchemaMerger, sort::SortKeyBuilder, SchemaBuilder, TIME_COLUMN_NAME}; + + use super::*; + + #[test] + fn test_roundtrip_empty() { + let schema = chunk(1).schema().as_arrow(); + assert_roundtrip(schema, vec![], None); + } + + #[test] + fn test_roundtrip_single_record_batch() { + let chunk1 = chunk(1); + let sort_key = Some(sort_key()); + assert_roundtrip(chunk1.schema().as_arrow(), vec![Arc::new(chunk1)], sort_key); + } + + #[test] + fn test_roundtrip_single_parquet() { + let chunk1 = chunk(1).with_dummy_parquet_file(); + let sort_key = Some(sort_key()); + assert_roundtrip(chunk1.schema().as_arrow(), vec![Arc::new(chunk1)], sort_key); + } + + #[test] + fn test_roundtrip_many_chunks() { + let chunk1 = chunk(1).with_dummy_parquet_file(); + let chunk2 = chunk(2).with_dummy_parquet_file(); + let chunk3 = chunk(3).with_dummy_parquet_file(); + let chunk4 = chunk(4); + let chunk5 = chunk(5); + let sort_key = Some(sort_key()); + assert_roundtrip( + chunk1.schema().as_arrow(), + vec![ + Arc::new(chunk1), + Arc::new(chunk2), + Arc::new(chunk3), + Arc::new(chunk4), + Arc::new(chunk5), + ], + sort_key, + ); + } + + #[test] + fn test_different_schemas() { + let some_chunk = chunk(1); + let iox_schema = some_chunk.schema(); + let schema1 = iox_schema.as_arrow(); + let schema2 = iox_schema.select_by_indices(&[]).as_arrow(); + let plan = UnionExec::new(vec![ + Arc::new(EmptyExec::new(schema1)), + Arc::new(EmptyExec::new(schema2)), + ]); + assert!(extract_chunks(&plan).is_none()); + } + + #[test] + fn test_empty_exec_with_rows() { + let schema = chunk(1).schema().as_arrow(); + let plan = PlaceholderRowExec::new(schema); + assert!(extract_chunks(&plan).is_none()); + } + + #[test] + fn test_empty_exec_no_iox_schema() { + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + "x", + DataType::Float64, + true, + )])); + let plan = EmptyExec::new(Arc::clone(&schema)); + let (schema2, chunks, sort_key) = extract_chunks(&plan).unwrap(); + 
assert_eq!(schema, schema2); + assert!(chunks.is_empty()); + assert!(sort_key.is_none()); + } + + #[test] + fn test_different_sort_keys() { + let sort_key1 = Arc::new(SortKeyBuilder::new().with_col("tag1").build()); + let sort_key2 = Arc::new(SortKeyBuilder::new().with_col("tag2").build()); + let chunk1 = Arc::new(chunk(1)) as Arc; + let schema = chunk1.schema().as_arrow(); + let plan = UnionExec::new(vec![ + chunks_to_physical_nodes(&schema, Some(&sort_key1), vec![Arc::clone(&chunk1)], 1), + chunks_to_physical_nodes(&schema, Some(&sort_key2), vec![chunk1], 1), + ]); + assert!(extract_chunks(&plan).is_none()); + } + + #[test] + fn test_stop_at_other_node_types() { + let chunk1 = chunk(1); + let schema = chunk1.schema().as_arrow(); + let plan = chunks_to_physical_nodes(&schema, None, vec![Arc::new(chunk1)], 2); + let plan = FilterExec::try_new( + df_physical_expr(plan.schema(), col("tag1").eq(lit("foo"))).unwrap(), + plan, + ) + .unwrap(); + assert!(extract_chunks(&plan).is_none()); + } + + #[test] + fn test_preserve_record_batches_exec_schema() { + let chunk = chunk(1); + let schema_ext = SchemaBuilder::new().tag("zzz").build().unwrap(); + let schema = SchemaMerger::new() + .merge(chunk.schema()) + .unwrap() + .merge(&schema_ext) + .unwrap() + .build() + .as_arrow(); + assert_roundtrip(schema, vec![Arc::new(chunk)], None); + } + + #[test] + fn test_preserve_parquet_exec_schema() { + let chunk = chunk(1).with_dummy_parquet_file(); + let schema_ext = SchemaBuilder::new().tag("zzz").build().unwrap(); + let schema = SchemaMerger::new() + .merge(chunk.schema()) + .unwrap() + .merge(&schema_ext) + .unwrap() + .build() + .as_arrow(); + assert_roundtrip(schema, vec![Arc::new(chunk)], None); + } + + #[test] + fn test_parquet_with_predicate_fails() { + let chunk = chunk(1).with_dummy_parquet_file(); + let schema = chunk.schema().as_arrow(); + let plan = chunks_to_physical_nodes(&schema, None, vec![Arc::new(chunk)], 2); + let plan = plan + .transform_down(&|plan| { + if let 
Some(exec) = plan.as_any().downcast_ref::() { + let exec = ParquetExec::new( + exec.base_config().clone(), + Some(Arc::new(Literal::new(ScalarValue::from(false)))), + None, + ); + return Ok(Transformed::Yes(Arc::new(exec))); + } + Ok(Transformed::No(plan)) + }) + .unwrap(); + assert!(extract_chunks(plan.as_ref()).is_none()); + } + + #[track_caller] + fn assert_roundtrip( + schema: SchemaRef, + chunks: Vec>, + output_sort_key: Option, + ) { + let plan = chunks_to_physical_nodes(&schema, output_sort_key.as_ref(), chunks.clone(), 2); + let (schema2, chunks2, output_sort_key2) = + extract_chunks(plan.as_ref()).expect("data found"); + assert_eq!(schema, schema2); + assert_eq!(chunk_ids(&chunks), chunk_ids(&chunks2)); + assert_eq!(output_sort_key, output_sort_key2); + } + + fn chunk_ids(chunks: &[Arc]) -> Vec { + let mut ids = chunks.iter().map(|c| c.id()).collect::>(); + ids.sort(); + ids + } + + fn chunk(id: u128) -> TestChunk { + TestChunk::new("table") + .with_id(id) + .with_tag_column("tag1") + .with_tag_column("tag2") + .with_i64_field_column("field") + .with_time_column() + } + + fn sort_key() -> SortKey { + SortKeyBuilder::new() + .with_col("tag2") + .with_col("tag1") + .with_col(TIME_COLUMN_NAME) + .build() + } +} diff --git a/iox_query/src/physical_optimizer/combine_chunks.rs b/iox_query/src/physical_optimizer/combine_chunks.rs new file mode 100644 index 0000000..d09681e --- /dev/null +++ b/iox_query/src/physical_optimizer/combine_chunks.rs @@ -0,0 +1,436 @@ +use std::sync::Arc; + +use arrow::compute::SortOptions; +use datafusion::{ + common::{ + plan_err, + tree_node::{Transformed, TreeNode}, + }, + config::ConfigOptions, + error::{DataFusionError, Result}, + physical_optimizer::PhysicalOptimizerRule, + physical_plan::{union::UnionExec, ExecutionPlan}, +}; +use observability_deps::tracing::trace; +use schema::TIME_COLUMN_NAME; + +use crate::{ + physical_optimizer::{ + chunk_extraction::extract_chunks, + sort::util::{collect_statistics_min_max, 
sort_by_value_ranges}, + }, + provider::chunks_to_physical_nodes, +}; + +/// Collects [`QueryChunk`]s and re-creates appropriate physical nodes. +/// +/// Invariants of inputs of the union: +/// 1. They do not overlap on time ranges (done in previous step: TimeSplit) +/// 2. Each input of the union is either with_chunks or other_plans. +/// - An input with_chunks is a plan that contains only (union of) ParquetExecs or RecordBatchesExec +/// - An input of other_plans is a plan that contains at least one node that is not a ParquetExec or +/// RecordBatchesExec or Union of them. Examples of those other nodes are FilterExec, DeduplicateExec, +/// ProjectionExec, etc. +// +/// Goals of this optimization step: +/// i. Combine **possible** plans with_chunks into a single union +/// ii. - Keep the combined plan non-overlapped on time ranges. This will likely help later optimization steps. +/// - If time ranges cannot be computed, combine all plans with_chunks into a single union. +/// +/// Example: w = with_chunks, o = other_plans +/// Input: |--P1 w --| |--P2 w --| |-- P3 o --| |-- P4 w --| |-- P5 w --| |-- P6 o --| |--P7 w --| +/// Output when time ranges can be computed: Only two sets of plans that are combined: [P1, P2], [P4, P5] +/// |------ P1 & P2 w ----| |-- P3 o --| |------ P4 & P5 w ------| |-- P6 o --| |--P7 w --| +/// Output when time ranges cannot be computed: all plans with_chunks are combined into a single union +/// |-------------------------- P1, P2, P4, P5, P7 w -------------------------------------| +/// |-- P3 o --| |-- P6 o --| +/// +/// +/// This is mostly useful after multiple re-arrangements (e.g. [`PartitionSplit`]-[`TimeSplit`]-[`RemoveDedup`]) created +/// a bunch of freestanding chunks that can be re-arranged into more packed, more efficient physical nodes.
+/// +/// +/// [`PartitionSplit`]: super::dedup::partition_split::PartitionSplit +/// [`QueryChunk`]: crate::QueryChunk +/// [`RemoveDedup`]: super::dedup::remove_dedup::RemoveDedup +/// [`TimeSplit`]: super::dedup::time_split::TimeSplit +#[derive(Debug, Default)] +pub struct CombineChunks; + +impl PhysicalOptimizerRule for CombineChunks { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + plan.transform_up(&|plan| { + if let Some(union_exec) = plan.as_any().downcast_ref::() { + // sort and group the inputs by time range + let inputs = union_exec.inputs(); + // We only need to ensure the input are sorted by time range, + // any order is fine and hence we choose to go with ASC here + let groups = sort_and_group_plans( + inputs.clone(), + TIME_COLUMN_NAME, + SortOptions { + descending: false, + nulls_first: false, + }, + )?; + + // combine plans from each group + let plans = groups + .into_iter() + .map(|group| combine_plans(group, config)) + .collect::>>()? + .into_iter() + .flatten() + .collect::>(); + + let final_union = UnionExec::new(plans); + trace!(?final_union, "-------- final union"); + return Ok(Transformed::Yes(Arc::new(final_union))); + } + + Ok(Transformed::No(plan)) + }) + } + + fn name(&self) -> &str { + "combine_chunks" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// Sort the given plans on the given column name and a given sort order. +/// +/// Then group them into non-overlapped groups based on the ranges of the given column, and return the groups. 
+/// +/// # Input Invariants +/// - Plans do not overlap on the given column +/// +/// # Output Invariants +/// - Plans in the same group do not overlap on the given column +/// -The groups do not overlap on the given column +/// +/// # Example +/// Input: +/// +/// ```text +/// 7 plans with value ranges : |--P1 w --| |--P2 w --| |-- P3 o --| |-- P4 w --| |-- P5 w --| |-- P6 o --| |--P7 w --| +/// ``` +/// +/// Output: +/// +/// ```text +/// 5 groups: [P1, P2], [P3], [P4, P5], [P6], [P7] +/// ``` +fn sort_and_group_plans( + plans: Vec>, + col_name: &str, + sort_options: SortOptions, +) -> Result>>> { + if plans.len() <= 1 { + return Ok(vec![plans]); + } + + let Some(value_ranges) = collect_statistics_min_max(&plans, col_name)? else { + // No statistics to sort and group the plans. + // Return all plans in the same group + trace!("-------- combine chunks - cannot collect statistics min max for column {col_name}"); + return Ok(vec![plans]); + }; + + // Sort the plans by their value ranges + trace!("-------- value_ranges: {:?}", value_ranges); + let Some(plans_value_ranges) = sort_by_value_ranges(plans.clone(), value_ranges, sort_options)? + else { + // The inputs are not being sorted by value ranges, cannot group them + // Return all plans in the same group + trace!("-------- inputs are not sorted by value ranges. No optimization"); + return Ok(vec![plans]); + }; + + // Group plans that can be combined + let plans = plans_value_ranges.plans; + let mut final_groups = Vec::with_capacity(plans.len()); + let mut combinable_plans = Vec::new(); + for plan in plans { + if extract_chunks(plan.as_ref()).is_some() { + combinable_plans.push(plan); + } else { + if !combinable_plans.is_empty() { + final_groups.push(combinable_plans); + combinable_plans = Vec::new(); + } + final_groups.push(vec![plan]); + } + } + + if !combinable_plans.is_empty() { + final_groups.push(combinable_plans); + } + + Ok(final_groups) +} + +/// Combine the given plans with chunks into a single union. 
The other plans stay as is. +fn combine_plans( + plans: Vec>, + config: &ConfigOptions, +) -> Result>> { + let (inputs_with_chunks, inputs_other): (Vec<_>, Vec<_>) = plans + .iter() + .cloned() + .partition(|plan| extract_chunks(plan.as_ref()).is_some()); + + if inputs_with_chunks.is_empty() { + return Ok(plans); + } + let union_of_chunks = UnionExec::new(inputs_with_chunks); + + if let Some((schema, chunks, output_sort_key)) = extract_chunks(&union_of_chunks) { + let union_of_chunks = chunks_to_physical_nodes( + &schema, + output_sort_key.as_ref(), + chunks, + config.execution.target_partitions, + ); + let Some(union_of_chunks) = union_of_chunks.as_any().downcast_ref::() else { + return plan_err!("Expected chunks_to_physical_nodes to produce UnionExec but got {union_of_chunks:?}"); + }; + + // return other_plans and the union_of_chunks + let plans = union_of_chunks + .inputs() + .iter() + .cloned() + .chain(inputs_other) + .collect(); + return Ok(plans); + } + + Ok(plans) +} + +#[cfg(test)] +mod tests { + use datafusion::{ + physical_plan::{expressions::Literal, filter::FilterExec, union::UnionExec}, + scalar::ScalarValue, + }; + + use crate::{physical_optimizer::test_util::OptimizationTest, test::TestChunk, QueryChunk}; + + use super::*; + + #[test] + fn test_combine_single_union_tree() { + let chunk1 = TestChunk::new("table") + .with_id(1) + .with_time_column_with_stats(Some(1), Some(2)); + let chunk2 = TestChunk::new("table") + .with_id(2) + .with_dummy_parquet_file() + .with_time_column_with_stats(Some(3), Some(4)); + let chunk3 = TestChunk::new("table") + .with_id(3) + .with_time_column_with_stats(Some(5), Some(6)); + let chunk4 = TestChunk::new("table") + .with_id(4) + .with_dummy_parquet_file() + .with_time_column_with_stats(Some(7), Some(8)); + let chunk5 = TestChunk::new("table") + .with_id(5) + .with_dummy_parquet_file() + .with_time_column_with_stats(Some(9), Some(10)); + let schema = chunk1.schema().as_arrow(); + let plan = 
Arc::new(UnionExec::new(vec![ + chunks_to_physical_nodes(&schema, None, vec![Arc::new(chunk1), Arc::new(chunk2)], 2), + chunks_to_physical_nodes( + &schema, + None, + vec![Arc::new(chunk3), Arc::new(chunk4), Arc::new(chunk5)], + 2, + ), + ])); + let opt = CombineChunks; + let mut config = ConfigOptions::default(); + config.execution.target_partitions = 2; + insta::assert_yaml_snapshot!( + OptimizationTest::new_with_config(plan, opt, &config), + @r###" + --- + input: + - " UnionExec" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[time]" + - " ParquetExec: file_groups={1 group: [[2.parquet]]}, projection=[time]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[time]" + - " ParquetExec: file_groups={2 groups: [[4.parquet], [5.parquet]]}, projection=[time]" + output: + Ok: + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[time]" + - " ParquetExec: file_groups={2 groups: [[2.parquet, 5.parquet], [4.parquet]]}, projection=[time]" + "### + ); + } + + #[test] + fn test_only_combine_contiguous_arms() { + let chunk1 = TestChunk::new("table") + .with_id(1) + .with_dummy_parquet_file() + .with_time_column_with_stats(Some(1), Some(2)); + let chunk2 = TestChunk::new("table") + .with_id(2) + .with_dummy_parquet_file() + .with_time_column_with_stats(Some(3), Some(4)); + let chunk3 = TestChunk::new("table") + .with_id(3) + .with_dummy_parquet_file() + .with_time_column_with_stats(Some(5), Some(6)); + let chunk4 = TestChunk::new("table") + .with_id(4) + .with_dummy_parquet_file() + .with_time_column_with_stats(Some(7), Some(8)); + let schema = chunk1.schema().as_arrow(); + let plan = Arc::new(UnionExec::new(vec![ + chunks_to_physical_nodes(&schema, None, vec![Arc::new(chunk1)], 2), + chunks_to_physical_nodes(&schema, None, vec![Arc::new(chunk2)], 2), + Arc::new( + FilterExec::try_new( + Arc::new(Literal::new(ScalarValue::from(false))), + chunks_to_physical_nodes(&schema, None, vec![Arc::new(chunk3)], 2), + ) + .unwrap(), + ), + 
chunks_to_physical_nodes(&schema, None, vec![Arc::new(chunk4)], 2), + ])); + let opt = CombineChunks; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " UnionExec" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[time]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[2.parquet]]}, projection=[time]" + - " FilterExec: false" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[3.parquet]]}, projection=[time]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[4.parquet]]}, projection=[time]" + output: + Ok: + - " UnionExec" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[time]" + - " FilterExec: false" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[3.parquet]]}, projection=[time]" + - " ParquetExec: file_groups={1 group: [[4.parquet]]}, projection=[time]" + "### + ); + } + + #[test] + fn test_combine_some_union_arms() { + let chunk1 = TestChunk::new("table").with_id(1).with_dummy_parquet_file(); + let chunk2 = TestChunk::new("table").with_id(1).with_dummy_parquet_file(); + let chunk3 = TestChunk::new("table").with_id(1).with_dummy_parquet_file(); + let schema = chunk1.schema().as_arrow(); + let plan = Arc::new(UnionExec::new(vec![ + chunks_to_physical_nodes(&schema, None, vec![Arc::new(chunk1)], 2), + chunks_to_physical_nodes(&schema, None, vec![Arc::new(chunk2)], 2), + Arc::new( + FilterExec::try_new( + Arc::new(Literal::new(ScalarValue::from(false))), + chunks_to_physical_nodes(&schema, None, vec![Arc::new(chunk3)], 2), + ) + .unwrap(), + ), + ])); + let opt = CombineChunks; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " UnionExec" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}" + - " FilterExec: false" + - " UnionExec" + - " ParquetExec: file_groups={1 
group: [[1.parquet]]}" + output: + Ok: + - " UnionExec" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [1.parquet]]}" + - " FilterExec: false" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}" + "### + ); + } + + #[test] + fn test_no_chunks() { + let chunk1 = TestChunk::new("table").with_id(1); + let schema = chunk1.schema().as_arrow(); + let plan = chunks_to_physical_nodes(&schema, None, vec![], 2); + let opt = CombineChunks; + let mut config = ConfigOptions::default(); + config.execution.target_partitions = 2; + insta::assert_yaml_snapshot!( + OptimizationTest::new_with_config(plan, opt, &config), + @r###" + --- + input: + - " EmptyExec" + output: + Ok: + - " EmptyExec" + "### + ); + } + + #[test] + fn test_no_valid_arms() { + let chunk1 = TestChunk::new("table").with_id(1); + let schema = chunk1.schema().as_arrow(); + let plan = Arc::new(UnionExec::new(vec![Arc::new( + FilterExec::try_new( + Arc::new(Literal::new(ScalarValue::from(false))), + chunks_to_physical_nodes(&schema, None, vec![Arc::new(chunk1)], 2), + ) + .unwrap(), + )])); + let opt = CombineChunks; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " UnionExec" + - " FilterExec: false" + - " UnionExec" + - " RecordBatchesExec: chunks=1" + output: + Ok: + - " UnionExec" + - " FilterExec: false" + - " UnionExec" + - " RecordBatchesExec: chunks=1" + "### + ); + } +} diff --git a/iox_query/src/physical_optimizer/dedup/dedup_null_columns.rs b/iox_query/src/physical_optimizer/dedup/dedup_null_columns.rs new file mode 100644 index 0000000..341ae47 --- /dev/null +++ b/iox_query/src/physical_optimizer/dedup/dedup_null_columns.rs @@ -0,0 +1,249 @@ +use std::{collections::HashSet, sync::Arc}; + +use datafusion::{ + common::tree_node::{Transformed, TreeNode}, + config::ConfigOptions, + error::Result, + physical_optimizer::PhysicalOptimizerRule, + physical_plan::ExecutionPlan, +}; +use schema::{sort::SortKeyBuilder, TIME_COLUMN_NAME}; + 
+use crate::{ + physical_optimizer::chunk_extraction::extract_chunks, + provider::{chunks_to_physical_nodes, DeduplicateExec}, + util::arrow_sort_key_exprs, +}; + +/// Determine sort key set of [`DeduplicateExec`] by eliminating all-NULL columns. +/// +/// This finds a good sort key for [`DeduplicateExec`] based on the [`QueryChunk`]s covered by the deduplication. +/// +/// We assume that columns that are NOT present in any chunks and hence are only created as pure NULL-columns are +/// not relevant for deduplication since they are effectively constant. +/// +/// +/// [`QueryChunk`]: crate::QueryChunk +#[derive(Debug, Default)] +pub struct DedupNullColumns; + +impl PhysicalOptimizerRule for DedupNullColumns { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + plan.transform_up(&|plan| { + let plan_any = plan.as_any(); + + if let Some(dedup_exec) = plan_any.downcast_ref::() { + let mut children = dedup_exec.children(); + assert_eq!(children.len(), 1); + let child = children.remove(0); + let Some((schema, chunks, _output_sort_key)) = extract_chunks(child.as_ref()) + else { + return Ok(Transformed::No(plan)); + }; + + let pk_cols = dedup_exec.sort_columns(); + + let mut used_pk_cols = HashSet::new(); + for chunk in &chunks { + for (_type, field) in chunk.schema().iter() { + if pk_cols.contains(field.name().as_str()) { + used_pk_cols.insert(field.name().as_str()); + } + } + } + + let mut used_pk_cols = used_pk_cols.into_iter().collect::>(); + used_pk_cols.sort_by_key(|col| (*col == TIME_COLUMN_NAME, *col)); + + let mut sort_key_builder = SortKeyBuilder::new(); + for col in used_pk_cols { + sort_key_builder = sort_key_builder.with_col(col); + } + + let sort_key = sort_key_builder.build(); + let child = chunks_to_physical_nodes( + &schema, + (!sort_key.is_empty()).then_some(&sort_key), + chunks, + config.execution.target_partitions, + ); + + let sort_exprs = arrow_sort_key_exprs(&sort_key, &schema); + return
Ok(Transformed::Yes(Arc::new(DeduplicateExec::new( + child, + sort_exprs, + dedup_exec.use_chunk_order_col(), + )))); + } + + Ok(Transformed::No(plan)) + }) + } + + fn name(&self) -> &str { + "dedup_null_columns" + } + + fn schema_check(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use schema::SchemaBuilder; + + use crate::{ + physical_optimizer::{ + dedup::test_util::{chunk, dedup_plan, dedup_plan_with_chunk_order_col}, + test_util::OptimizationTest, + }, + test::TestChunk, + QueryChunk, + }; + + use super::*; + + #[test] + fn test_no_chunks() { + let schema = chunk(1).schema().clone(); + let plan = dedup_plan(schema, vec![]); + let opt = DedupNullColumns; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " EmptyExec" + output: + Ok: + - " DeduplicateExec: []" + - " EmptyExec" + "### + ); + } + + #[test] + fn test_single_chunk_all_cols() { + let chunk = chunk(1).with_dummy_parquet_file(); + let schema = chunk.schema().clone(); + let plan = dedup_plan(schema, vec![chunk]); + let opt = DedupNullColumns; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[field, tag1, tag2, time]" + output: + Ok: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[field, tag1, tag2, time]" + "### + ); + } + + #[test] + fn test_single_chunk_schema_has_chunk_order_col() { + let chunk = chunk(1).with_dummy_parquet_file(); + let schema = chunk.schema().clone(); + let plan = dedup_plan_with_chunk_order_col(schema, vec![chunk]); + let opt = DedupNullColumns; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 
ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[field, tag1, tag2, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + output: + Ok: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[field, tag1, tag2, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + } + + #[test] + fn test_single_chunk_misses_pk_cols() { + let chunk = TestChunk::new("table") + .with_id(1) + .with_tag_column("tag1") + .with_dummy_parquet_file(); + let schema = SchemaBuilder::new() + .tag("tag1") + .tag("tag2") + .tag("zzz") + .timestamp() + .build() + .unwrap(); + let plan = dedup_plan(schema, vec![chunk]); + let opt = DedupNullColumns; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@0 ASC,tag2@1 ASC,zzz@2 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[tag1, tag2, zzz, time]" + output: + Ok: + - " DeduplicateExec: [tag1@0 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[tag1, tag2, zzz, time]" + "### + ); + } + + #[test] + fn test_two_chunks() { + let chunk1 = TestChunk::new("table") + .with_id(1) + .with_tag_column("tag1") + .with_tag_column("tag2") + .with_time_column() + .with_dummy_parquet_file(); + let chunk2 = TestChunk::new("table") + .with_id(2) + .with_tag_column("tag1") + .with_tag_column("tag3") + .with_time_column() + .with_dummy_parquet_file(); + let schema = SchemaBuilder::new() + .tag("tag1") + .tag("tag2") + .tag("tag3") + .tag("tag4") + .timestamp() + .build() + .unwrap(); + let plan = dedup_plan(schema, vec![chunk1, chunk2]); + let opt = DedupNullColumns; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@0 ASC,tag2@1 ASC,tag3@2 ASC,tag4@3 ASC,time@4 ASC]" + - 
" UnionExec" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[tag1, tag2, tag3, tag4, time]" + output: + Ok: + - " DeduplicateExec: [tag1@0 ASC,tag2@1 ASC,tag3@2 ASC,time@4 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[tag1, tag2, tag3, tag4, time]" + "### + ); + } +} diff --git a/iox_query/src/physical_optimizer/dedup/dedup_sort_order.rs b/iox_query/src/physical_optimizer/dedup/dedup_sort_order.rs new file mode 100644 index 0000000..c4b3924 --- /dev/null +++ b/iox_query/src/physical_optimizer/dedup/dedup_sort_order.rs @@ -0,0 +1,636 @@ +use std::{cmp::Reverse, sync::Arc}; + +use arrow::compute::SortOptions; +use datafusion::{ + common::tree_node::{Transformed, TreeNode}, + config::ConfigOptions, + error::Result, + physical_optimizer::PhysicalOptimizerRule, + physical_plan::ExecutionPlan, +}; +use indexmap::IndexSet; +use schema::{sort::SortKeyBuilder, TIME_COLUMN_NAME}; + +use crate::{ + physical_optimizer::chunk_extraction::extract_chunks, + provider::{chunks_to_physical_nodes, DeduplicateExec}, + util::arrow_sort_key_exprs, + CHUNK_ORDER_COLUMN_NAME, +}; + +/// Determine sort key order of [`DeduplicateExec`]. +/// +/// This finds a cheap sort key order for [`DeduplicateExec`] based on the [`QueryChunk`]s covered by the deduplication. +/// This means that the sort key of the [`DeduplicateExec`] should be as close as possible to the pre-sorted chunks to +/// avoid resorting. If all chunks are pre-sorted (or not sorted at all), this is basically the joined merged sort key +/// of all of them. If the chunks do not agree on a single sort order[^different_orders], then we use a vote-based +/// system where we column-by-column pick the sort key order in the hope that this does the least harm. +/// +/// The produces sort key MUST be the same set of columns as before, i.e. this rule does NOT change the column set, it +/// only changes the order. 
+/// +/// We assume that the order of the sort key passed to [`DeduplicateExec`] is not relevant for correctness. +/// +/// This optimizer makes no assumption about how the ingester or compaction tier work or how chunks relate to each +/// other. As a consequence, it does NOT use the partition sort key. +/// +/// +/// [^different_orders]: In an ideal system, all chunks that have a sort order should agree on a single one. However we +/// want to avoid that the querier disintegrates when the ingester or compactor are buggy or when manual +/// interventions (like manual file creations) insert files that are slightly off. +/// +/// +/// [`QueryChunk`]: crate::QueryChunk +#[derive(Debug, Default)] +pub struct DedupSortOrder; + +impl PhysicalOptimizerRule for DedupSortOrder { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + plan.transform_up(&|plan| { + let plan_any = plan.as_any(); + + if let Some(dedup_exec) = plan_any.downcast_ref::() { + let mut children = dedup_exec.children(); + assert_eq!(children.len(), 1); + let child = children.remove(0); + let Some((schema, chunks, _output_sort_key)) = extract_chunks(child.as_ref()) + else { + return Ok(Transformed::No(plan)); + }; + + let mut chunk_sort_keys: Vec> = chunks + .iter() + .map(|chunk| { + chunk + .sort_key() + .map(|sort_key| { + sort_key + .iter() + .map(|(col, opts)| { + assert_eq!(opts, &SortOptions::default()); + col.as_ref() + }) + .collect() + }) + .unwrap_or_default() + }) + .collect(); + + let mut quorum_sort_key_builder = SortKeyBuilder::default(); + let mut todo_pk_columns = dedup_exec.sort_columns(); + todo_pk_columns.remove(CHUNK_ORDER_COLUMN_NAME); + while !todo_pk_columns.is_empty() { + let candidate_counts = todo_pk_columns.iter().copied().map(|col| { + let count = chunk_sort_keys + .iter() + .filter(|sort_key| { + match sort_key.get_index_of(col) { + Some(0) => { + // Column next in sort order from this chunks PoV. This is good. 
+ true + } + Some(_) => { + // Column part of the sort order but we have at least one more column before + // that. Try to avoid an expensive resort for this chunk. + false + } + None => { + // Column is not in the sort order of this chunk at all. Hence we can place it + // everywhere in the quorum sort key w/o having to worry about this particular + // chunk. + true + } + } + }) + .count(); + (col, count) + }); + let candidate_counts = sorted( + candidate_counts + .into_iter() + .map(|(col, count)| (Reverse(count), col == TIME_COLUMN_NAME, col)), + ); + let next_key = candidate_counts.first().expect("all TODO cols inserted").2; + + for chunk_sort_key in &mut chunk_sort_keys { + chunk_sort_key.shift_remove_full(next_key); + } + + let was_present = todo_pk_columns.remove(next_key); + assert!(was_present); + + quorum_sort_key_builder = quorum_sort_key_builder.with_col(next_key); + } + + let quorum_sort_key = quorum_sort_key_builder.build(); + let child = chunks_to_physical_nodes( + &schema, + (!quorum_sort_key.is_empty()).then_some(&quorum_sort_key), + chunks, + config.execution.target_partitions, + ); + + let sort_exprs = arrow_sort_key_exprs(&quorum_sort_key, &schema); + return Ok(Transformed::Yes(Arc::new(DeduplicateExec::new( + child, + sort_exprs, + dedup_exec.use_chunk_order_col(), + )))); + } + + Ok(Transformed::No(plan)) + }) + } + + fn name(&self) -> &str { + "dedup_sort_order" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// Collect items into a sorted vector. 
+fn sorted(it: impl IntoIterator) -> Vec +where + T: Ord, +{ + let mut items = it.into_iter().collect::>(); + items.sort(); + items +} + +#[cfg(test)] +mod tests { + use schema::{sort::SortKey, SchemaBuilder, TIME_COLUMN_NAME}; + + use crate::{ + physical_optimizer::{ + dedup::test_util::{chunk, dedup_plan, dedup_plan_with_chunk_order_col}, + test_util::OptimizationTest, + }, + test::TestChunk, + QueryChunk, + }; + + use super::*; + + #[test] + fn test_no_chunks() { + let schema = chunk(1).schema().clone(); + let plan = dedup_plan(schema, vec![]); + let opt = DedupSortOrder; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " EmptyExec" + output: + Ok: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " EmptyExec" + "### + ); + } + + #[test] + fn test_single_chunk_no_sort_key() { + let chunk = chunk(1).with_dummy_parquet_file(); + let schema = chunk.schema().clone(); + let plan = dedup_plan(schema, vec![chunk]); + let opt = DedupSortOrder; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[field, tag1, tag2, time]" + output: + Ok: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[field, tag1, tag2, time]" + "### + ); + } + + #[test] + fn test_single_chunk_order() { + let chunk = chunk(1) + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from("tag2"), + Arc::from("tag1"), + Arc::from(TIME_COLUMN_NAME), + ])); + let schema = chunk.schema().clone(); + let plan = dedup_plan(schema, vec![chunk]); + let opt = DedupSortOrder; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 
ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[field, tag1, tag2, time], output_ordering=[tag2@2 ASC, tag1@1 ASC, time@3 ASC]" + output: + Ok: + - " DeduplicateExec: [tag2@2 ASC,tag1@1 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[field, tag1, tag2, time], output_ordering=[tag2@2 ASC, tag1@1 ASC, time@3 ASC]" + "### + ); + } + + #[test] + fn test_single_chunk_with_chunk_order_col() { + let chunk = chunk(1) + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from("tag2"), + Arc::from("tag1"), + Arc::from(TIME_COLUMN_NAME), + ])); + let schema = chunk.schema().clone(); + let plan = dedup_plan_with_chunk_order_col(schema, vec![chunk]); + let opt = DedupSortOrder; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[field, tag1, tag2, time, __chunk_order], output_ordering=[tag2@2 ASC, tag1@1 ASC, time@3 ASC, __chunk_order@4 ASC]" + output: + Ok: + - " DeduplicateExec: [tag2@2 ASC,tag1@1 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[field, tag1, tag2, time, __chunk_order], output_ordering=[tag2@2 ASC, tag1@1 ASC, time@3 ASC, __chunk_order@4 ASC]" + "### + ); + } + + #[test] + fn test_unusual_time_order() { + let chunk = chunk(1) + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from(TIME_COLUMN_NAME), + Arc::from("tag1"), + Arc::from("tag2"), + ])); + let schema = chunk.schema().clone(); + let plan = dedup_plan(schema, vec![chunk]); + let opt = DedupSortOrder; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 
group: [[1.parquet]]}, projection=[field, tag1, tag2, time], output_ordering=[time@3 ASC, tag1@1 ASC, tag2@2 ASC]" + output: + Ok: + - " DeduplicateExec: [time@3 ASC,tag1@1 ASC,tag2@2 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[field, tag1, tag2, time], output_ordering=[time@3 ASC, tag1@1 ASC, tag2@2 ASC]" + "### + ); + } + + #[test] + fn test_single_chunk_time_always_included() { + let chunk = chunk(1) + .with_tag_column("zzz") + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from("tag2"), + Arc::from("tag1"), + ])); + let schema = chunk.schema().clone(); + let plan = dedup_plan(schema, vec![chunk]); + let opt = DedupSortOrder; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,zzz@4 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[field, tag1, tag2, time, zzz], output_ordering=[tag2@2 ASC, tag1@1 ASC]" + output: + Ok: + - " DeduplicateExec: [tag2@2 ASC,tag1@1 ASC,zzz@4 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[field, tag1, tag2, time, zzz]" + "### + ); + } + + #[test] + fn test_single_chunk_misses_pk_cols() { + let chunk = TestChunk::new("table") + .with_id(1) + .with_tag_column("tag1") + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([Arc::from("tag1")])); + let schema = SchemaBuilder::new() + .tag("tag1") + .tag("tag2") + .tag("zzz") + .timestamp() + .build() + .unwrap(); + let plan = dedup_plan(schema, vec![chunk]); + let opt = DedupSortOrder; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@0 ASC,tag2@1 ASC,zzz@2 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[tag1, tag2, zzz, time], output_ordering=[tag1@0 ASC]" + output: + Ok: + - " 
DeduplicateExec: [tag1@0 ASC,tag2@1 ASC,zzz@2 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}, projection=[tag1, tag2, zzz, time], output_ordering=[tag1@0 ASC, tag2@1 ASC, zzz@2 ASC, time@3 ASC]" + "### + ); + } + + #[test] + fn test_two_chunks_break_even_by_col_name() { + let chunk1 = chunk(1) + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from("tag1"), + Arc::from("tag2"), + Arc::from(TIME_COLUMN_NAME), + ])); + let chunk2 = chunk(2) + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from("tag2"), + Arc::from("tag1"), + Arc::from(TIME_COLUMN_NAME), + ])); + let schema = chunk1.schema().clone(); + let plan = dedup_plan(schema, vec![chunk1, chunk2]); + let opt = DedupSortOrder; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[field, tag1, tag2, time]" + output: + Ok: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[field, tag1, tag2, time]" + "### + ); + } + + #[test] + fn test_two_chunks_sorted_ranks_higher_than_not_sorted() { + let chunk1 = chunk(1) + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from("tag2"), + Arc::from(TIME_COLUMN_NAME), + ])); + let chunk2 = chunk(2) + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from("tag2"), + Arc::from("tag1"), + Arc::from(TIME_COLUMN_NAME), + ])); + let schema = chunk1.schema().clone(); + let plan = dedup_plan(schema, vec![chunk1, chunk2]); + let opt = DedupSortOrder; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: 
file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[field, tag1, tag2, time], output_ordering=[tag2@2 ASC, tag1@1 ASC, time@3 ASC]" + output: + Ok: + - " DeduplicateExec: [tag2@2 ASC,tag1@1 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[field, tag1, tag2, time]" + "### + ); + } + + #[test] + fn test_two_chunks_one_without_sort_key() { + let chunk1 = chunk(1) + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from("tag2"), + Arc::from("tag1"), + Arc::from(TIME_COLUMN_NAME), + ])); + let chunk2 = chunk(2).with_dummy_parquet_file(); + let schema = chunk1.schema().clone(); + let plan = dedup_plan(schema, vec![chunk1, chunk2]); + let opt = DedupSortOrder; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[field, tag1, tag2, time]" + output: + Ok: + - " DeduplicateExec: [tag2@2 ASC,tag1@1 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[field, tag1, tag2, time]" + "### + ); + } + + #[test] + fn test_three_chunks_different_subsets() { + let chunk1 = TestChunk::new("table") + .with_id(1) + .with_tag_column("tag1") + .with_tag_column("tag2") + .with_time_column() + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from("tag2"), + Arc::from("tag1"), + Arc::from(TIME_COLUMN_NAME), + ])); + let chunk2 = TestChunk::new("table") + .with_id(2) + .with_tag_column("tag1") + .with_tag_column("tag3") + .with_time_column() + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from("tag3"), + Arc::from("tag1"), + Arc::from(TIME_COLUMN_NAME), + ])); + let chunk3 = TestChunk::new("table") + .with_id(3) + .with_tag_column("tag1") + .with_tag_column("tag2") + 
.with_tag_column("tag3") + .with_time_column() + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from("tag2"), + Arc::from("tag3"), + Arc::from("tag1"), + Arc::from(TIME_COLUMN_NAME), + ])); + let schema = chunk3.schema().clone(); + let plan = dedup_plan(schema, vec![chunk1, chunk2, chunk3]); + let opt = DedupSortOrder; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@0 ASC,tag2@1 ASC,tag3@2 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={2 groups: [[1.parquet, 3.parquet], [2.parquet]]}, projection=[tag1, tag2, tag3, time]" + output: + Ok: + - " DeduplicateExec: [tag2@1 ASC,tag3@2 ASC,tag1@0 ASC,time@3 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={3 groups: [[1.parquet], [3.parquet], [2.parquet]]}, projection=[tag1, tag2, tag3, time], output_ordering=[tag2@1 ASC, tag3@2 ASC, tag1@0 ASC, time@3 ASC]" + "### + ); + } + + #[test] + fn test_three_chunks_single_chunk_has_extra_col1() { + let chunk1 = TestChunk::new("table") + .with_id(1) + .with_tag_column("tag1") + .with_time_column() + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from("tag1"), + Arc::from(TIME_COLUMN_NAME), + ])); + let chunk2 = TestChunk::new("table") + .with_id(2) + .with_tag_column("tag1") + .with_time_column() + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from("tag1"), + Arc::from(TIME_COLUMN_NAME), + ])); + let chunk3 = TestChunk::new("table") + .with_id(3) + .with_tag_column("tag1") + .with_tag_column("tag2") + .with_time_column() + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from("tag2"), + Arc::from("tag1"), + Arc::from(TIME_COLUMN_NAME), + ])); + let schema = chunk3.schema().clone(); + let plan = dedup_plan(schema, vec![chunk1, chunk2, chunk3]); + let opt = DedupSortOrder; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " 
DeduplicateExec: [tag1@0 ASC,tag2@1 ASC,time@2 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={2 groups: [[1.parquet, 3.parquet], [2.parquet]]}, projection=[tag1, tag2, time], output_ordering=[tag2@1 ASC, tag1@0 ASC, time@2 ASC]" + output: + Ok: + - " DeduplicateExec: [tag2@1 ASC,tag1@0 ASC,time@2 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={3 groups: [[1.parquet], [3.parquet], [2.parquet]]}, projection=[tag1, tag2, time], output_ordering=[tag2@1 ASC, tag1@0 ASC, time@2 ASC]" + "### + ); + } + + #[test] + fn test_three_chunks_single_chunk_has_extra_col2() { + let chunk1 = TestChunk::new("table") + .with_id(1) + .with_tag_column("tag1") + .with_tag_column("tag2") + .with_time_column() + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from("tag1"), + Arc::from(TIME_COLUMN_NAME), + ])); + let chunk2 = TestChunk::new("table") + .with_id(2) + .with_tag_column("tag1") + .with_tag_column("tag2") + .with_time_column() + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from("tag1"), + Arc::from(TIME_COLUMN_NAME), + ])); + let chunk3 = TestChunk::new("table") + .with_id(3) + .with_tag_column("tag1") + .with_tag_column("tag2") + .with_time_column() + .with_dummy_parquet_file() + .with_sort_key(SortKey::from_columns([ + Arc::from("tag2"), + Arc::from("tag1"), + Arc::from(TIME_COLUMN_NAME), + ])); + let schema = chunk3.schema().clone(); + let plan = dedup_plan(schema, vec![chunk1, chunk2, chunk3]); + let opt = DedupSortOrder; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@0 ASC,tag2@1 ASC,time@2 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={2 groups: [[1.parquet, 3.parquet], [2.parquet]]}, projection=[tag1, tag2, time], output_ordering=[tag2@1 ASC, tag1@0 ASC, time@2 ASC]" + output: + Ok: + - " DeduplicateExec: [tag2@1 ASC,tag1@0 ASC,time@2 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={3 groups: [[1.parquet], 
[3.parquet], [2.parquet]]}, projection=[tag1, tag2, time]" + "### + ); + } +} diff --git a/iox_query/src/physical_optimizer/dedup/mod.rs b/iox_query/src/physical_optimizer/dedup/mod.rs new file mode 100644 index 0000000..813cd3b --- /dev/null +++ b/iox_query/src/physical_optimizer/dedup/mod.rs @@ -0,0 +1,10 @@ +//! Optimizer passes concering de-duplication. + +pub mod dedup_null_columns; +pub mod dedup_sort_order; +pub mod partition_split; +pub mod remove_dedup; +pub mod time_split; + +#[cfg(test)] +mod test_util; diff --git a/iox_query/src/physical_optimizer/dedup/partition_split.rs b/iox_query/src/physical_optimizer/dedup/partition_split.rs new file mode 100644 index 0000000..386cd9c --- /dev/null +++ b/iox_query/src/physical_optimizer/dedup/partition_split.rs @@ -0,0 +1,287 @@ +use crate::{ + config::IoxConfigExt, + physical_optimizer::chunk_extraction::extract_chunks, + provider::{chunks_to_physical_nodes, DeduplicateExec}, + QueryChunk, +}; +use datafusion::{ + common::tree_node::{Transformed, TreeNode}, + config::ConfigOptions, + error::Result, + physical_optimizer::PhysicalOptimizerRule, + physical_plan::{union::UnionExec, ExecutionPlan}, +}; +use hashbrown::HashMap; +use observability_deps::tracing::warn; +use std::sync::Arc; + +/// Split de-duplication operations based on partitons. +/// +/// This should usually be more cost-efficient. 
+#[derive(Debug, Default)] +pub struct PartitionSplit; + +impl PhysicalOptimizerRule for PartitionSplit { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + plan.transform_up(&|plan| { + let plan_any = plan.as_any(); + + if let Some(dedup_exec) = plan_any.downcast_ref::() { + let mut children = dedup_exec.children(); + assert_eq!(children.len(), 1); + let child = children.remove(0); + let Some((schema, chunks, output_sort_key)) = extract_chunks(child.as_ref()) else { + return Ok(Transformed::No(plan)); + }; + + let mut chunks_by_partition: HashMap<_, Vec>> = + Default::default(); + for chunk in chunks { + chunks_by_partition + .entry(chunk.partition_id().clone()) + .or_default() + .push(chunk); + } + + // If there not multiple partitions (0 or 1), then this optimizer is a no-op. Signal that to the + // optimizer framework. + if chunks_by_partition.len() < 2 { + return Ok(Transformed::No(plan)); + } + + // Protect against degenerative plans + let max_dedup_partition_split = config + .extensions + .get::() + .cloned() + .unwrap_or_default() + .max_dedup_partition_split; + if chunks_by_partition.len() > max_dedup_partition_split { + warn!( + n_partitions = chunks_by_partition.len(), + max_dedup_partition_split, + "cannot split dedup operation based on partition, too many partitions" + ); + return Ok(Transformed::No(plan)); + } + + // ensure deterministic order + let mut chunks_by_partition = chunks_by_partition.into_iter().collect::>(); + chunks_by_partition.sort_by(|a, b| a.0.cmp(&b.0)); + + let out = UnionExec::new( + chunks_by_partition + .into_iter() + .map(|(_p_id, chunks)| { + Arc::new(DeduplicateExec::new( + chunks_to_physical_nodes( + &schema, + output_sort_key.as_ref(), + chunks, + config.execution.target_partitions, + ), + dedup_exec.sort_keys().to_vec(), + dedup_exec.use_chunk_order_col(), + )) as _ + }) + .collect(), + ); + return Ok(Transformed::Yes(Arc::new(out))); + } + + Ok(Transformed::No(plan)) + }) + } + + fn 
name(&self) -> &str { + "partition_split" + } + + fn schema_check(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::physical_optimizer::{ + dedup::test_util::{chunk, dedup_plan}, + test_util::OptimizationTest, + }; + use data_types::{PartitionHashId, PartitionId, TransitionPartitionId}; + + #[test] + fn test_no_chunks() { + let schema = chunk(1).schema().clone(); + let plan = dedup_plan(schema, vec![]); + let opt = PartitionSplit; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " EmptyExec" + output: + Ok: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " EmptyExec" + "### + ); + } + + #[test] + fn test_same_partition() { + let chunk1 = chunk(1); + let chunk2 = chunk(2); + let chunk3 = chunk(3).with_dummy_parquet_file(); + let schema = chunk1.schema().clone(); + let plan = dedup_plan(schema, vec![chunk1, chunk2, chunk3]); + let opt = PartitionSplit; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[field, tag1, tag2, time]" + - " ParquetExec: file_groups={1 group: [[3.parquet]]}, projection=[field, tag1, tag2, time]" + output: + Ok: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[field, tag1, tag2, time]" + - " ParquetExec: file_groups={1 group: [[3.parquet]]}, projection=[field, tag1, tag2, time]" + "### + ); + } + + #[test] + fn test_different_partitions() { + let chunk1 = chunk(1).with_partition(1); + let chunk2 = chunk(2).with_partition(2); + // use at least 3 parquet files for one of the two partitions to validate that `target_partitions` is forwared correctly + let chunk3 = chunk(3).with_dummy_parquet_file().with_partition(1); + let chunk4 = 
chunk(4).with_dummy_parquet_file().with_partition(2); + let chunk5 = chunk(5).with_dummy_parquet_file().with_partition(1); + let chunk6 = chunk(6).with_dummy_parquet_file().with_partition(1); + let schema = chunk1.schema().clone(); + let plan = dedup_plan(schema, vec![chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]); + let opt = PartitionSplit; + let mut config = ConfigOptions::default(); + config.execution.target_partitions = 2; + insta::assert_yaml_snapshot!( + OptimizationTest::new_with_config(plan, opt, &config), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[field, tag1, tag2, time]" + - " ParquetExec: file_groups={2 groups: [[3.parquet, 5.parquet], [4.parquet, 6.parquet]]}, projection=[field, tag1, tag2, time]" + output: + Ok: + - " UnionExec" + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time]" + - " ParquetExec: file_groups={2 groups: [[3.parquet, 6.parquet], [5.parquet]]}, projection=[field, tag1, tag2, time]" + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time]" + - " ParquetExec: file_groups={1 group: [[4.parquet]]}, projection=[field, tag1, tag2, time]" + "### + ); + } + + #[test] + fn test_different_partitions_with_and_without_hash_ids() { + // Partition without hash ID in the catalog + let legacy_partition_id = 1; + let legacy_transition_partition_id = + TransitionPartitionId::Deprecated(PartitionId::new(legacy_partition_id)); + + // Partition with hash ID in the catalog + let transition_partition_id = + TransitionPartitionId::Deterministic(PartitionHashId::arbitrary_for_testing()); + + let chunk1 = chunk(1).with_partition_id(legacy_transition_partition_id.clone()); + let chunk2 = chunk(2).with_partition_id(transition_partition_id.clone()); + + let chunk3 = chunk(3) 
+ .with_dummy_parquet_file() + .with_partition_id(legacy_transition_partition_id.clone()); + let chunk4 = chunk(4) + .with_dummy_parquet_file() + .with_partition_id(transition_partition_id.clone()); + let chunk5 = chunk(5) + .with_dummy_parquet_file() + .with_partition_id(legacy_transition_partition_id.clone()); + let chunk6 = chunk(6) + .with_dummy_parquet_file() + .with_partition_id(legacy_transition_partition_id.clone()); + let schema = chunk1.schema().clone(); + let plan = dedup_plan(schema, vec![chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]); + let opt = PartitionSplit; + let mut config = ConfigOptions::default(); + config.execution.target_partitions = 2; + insta::assert_yaml_snapshot!( + OptimizationTest::new_with_config(plan, opt, &config), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[field, tag1, tag2, time]" + - " ParquetExec: file_groups={2 groups: [[3.parquet, 5.parquet], [4.parquet, 6.parquet]]}, projection=[field, tag1, tag2, time]" + output: + Ok: + - " UnionExec" + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time]" + - " ParquetExec: file_groups={2 groups: [[3.parquet, 6.parquet], [5.parquet]]}, projection=[field, tag1, tag2, time]" + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time]" + - " ParquetExec: file_groups={1 group: [[4.parquet]]}, projection=[field, tag1, tag2, time]" + "### + ); + } + + #[test] + fn test_max_split() { + let chunk1 = chunk(1).with_partition(1); + let chunk2 = chunk(2).with_partition(2); + let chunk3 = chunk(3).with_partition(3); + let schema = chunk1.schema().clone(); + let plan = dedup_plan(schema, vec![chunk1, chunk2, chunk3]); + let opt = PartitionSplit; + let mut config = ConfigOptions::default(); + 
config.extensions.insert(IoxConfigExt { + max_dedup_partition_split: 2, + ..Default::default() + }); + insta::assert_yaml_snapshot!( + OptimizationTest::new_with_config(plan, opt, &config), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=3, projection=[field, tag1, tag2, time]" + output: + Ok: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=3, projection=[field, tag1, tag2, time]" + "### + ); + } +} diff --git a/iox_query/src/physical_optimizer/dedup/remove_dedup.rs b/iox_query/src/physical_optimizer/dedup/remove_dedup.rs new file mode 100644 index 0000000..9558c5a --- /dev/null +++ b/iox_query/src/physical_optimizer/dedup/remove_dedup.rs @@ -0,0 +1,159 @@ +use std::sync::Arc; + +use datafusion::{ + common::tree_node::{Transformed, TreeNode}, + config::ConfigOptions, + error::Result, + physical_optimizer::PhysicalOptimizerRule, + physical_plan::ExecutionPlan, +}; + +use crate::{ + physical_optimizer::chunk_extraction::extract_chunks, + provider::{chunks_to_physical_nodes, DeduplicateExec}, +}; + +/// Removes de-duplication operation if there are at most 1 chunks and this chunk does NOT contain primary-key duplicates. 
+#[derive(Debug, Default)] +pub struct RemoveDedup; + +impl PhysicalOptimizerRule for RemoveDedup { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + plan.transform_up(&|plan| { + let plan_any = plan.as_any(); + + if let Some(dedup_exec) = plan_any.downcast_ref::() { + let mut children = dedup_exec.children(); + assert_eq!(children.len(), 1); + let child = children.remove(0); + let Some((schema, chunks, output_sort_key)) = extract_chunks(child.as_ref()) else { + return Ok(Transformed::No(plan)); + }; + + if (chunks.len() < 2) && chunks.iter().all(|c| !c.may_contain_pk_duplicates()) { + return Ok(Transformed::Yes(chunks_to_physical_nodes( + &schema, + output_sort_key.as_ref(), + chunks, + config.execution.target_partitions, + ))); + } + } + + Ok(Transformed::No(plan)) + }) + } + + fn name(&self) -> &str { + "remove_dedup" + } + + fn schema_check(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use crate::{ + physical_optimizer::{ + dedup::test_util::{chunk, dedup_plan}, + test_util::OptimizationTest, + }, + QueryChunk, + }; + + use super::*; + + #[test] + fn test_no_chunks() { + let schema = chunk(1).schema().clone(); + let plan = dedup_plan(schema, vec![]); + let opt = RemoveDedup; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " EmptyExec" + output: + Ok: + - " EmptyExec" + "### + ); + } + + #[test] + fn test_single_chunk_no_pk_dups() { + let chunk1 = chunk(1).with_may_contain_pk_duplicates(false); + let schema = chunk1.schema().clone(); + let plan = dedup_plan(schema, vec![chunk1]); + let opt = RemoveDedup; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time]" + output: + Ok: + - " UnionExec" + - " RecordBatchesExec: chunks=1, 
projection=[field, tag1, tag2, time]" + "### + ); + } + + #[test] + fn test_single_chunk_with_pk_dups() { + let chunk1 = chunk(1).with_may_contain_pk_duplicates(true); + let schema = chunk1.schema().clone(); + let plan = dedup_plan(schema, vec![chunk1]); + let opt = RemoveDedup; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time]" + output: + Ok: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time]" + "### + ); + } + + #[test] + fn test_multiple_chunks() { + let chunk1 = chunk(1).with_may_contain_pk_duplicates(false); + let chunk2 = chunk(2).with_may_contain_pk_duplicates(false); + let schema = chunk1.schema().clone(); + let plan = dedup_plan(schema, vec![chunk1, chunk2]); + let opt = RemoveDedup; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[field, tag1, tag2, time]" + output: + Ok: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[field, tag1, tag2, time]" + "### + ); + } +} diff --git a/iox_query/src/physical_optimizer/dedup/test_util.rs b/iox_query/src/physical_optimizer/dedup/test_util.rs new file mode 100644 index 0000000..20d8a99 --- /dev/null +++ b/iox_query/src/physical_optimizer/dedup/test_util.rs @@ -0,0 +1,62 @@ +use std::sync::Arc; + +use arrow::datatypes::{Fields, Schema as ArrowSchema}; +use datafusion::physical_plan::ExecutionPlan; +use schema::Schema; + +use crate::{ + chunk_order_field, + provider::{chunks_to_physical_nodes, DeduplicateExec}, + test::TestChunk, + util::arrow_sort_key_exprs, + QueryChunk, +}; + +pub fn 
dedup_plan(schema: Schema, chunks: Vec) -> Arc { + dedup_plan_impl(schema, chunks, false) +} + +pub fn dedup_plan_with_chunk_order_col( + schema: Schema, + chunks: Vec, +) -> Arc { + dedup_plan_impl(schema, chunks, true) +} + +fn dedup_plan_impl( + schema: Schema, + chunks: Vec, + use_chunk_order_col: bool, +) -> Arc { + let chunks = chunks + .into_iter() + .map(|c| Arc::new(c) as _) + .collect::>>(); + let arrow_schema = if use_chunk_order_col { + Arc::new(ArrowSchema::new( + schema + .as_arrow() + .fields + .iter() + .cloned() + .chain(std::iter::once(chunk_order_field())) + .collect::(), + )) + } else { + schema.as_arrow() + }; + let plan = chunks_to_physical_nodes(&arrow_schema, None, chunks, 2); + + let sort_key = schema::sort::SortKey::from_columns(schema.primary_key()); + let sort_exprs = arrow_sort_key_exprs(&sort_key, &plan.schema()); + Arc::new(DeduplicateExec::new(plan, sort_exprs, use_chunk_order_col)) +} + +pub fn chunk(id: u128) -> TestChunk { + TestChunk::new("table") + .with_id(id) + .with_tag_column("tag1") + .with_tag_column("tag2") + .with_i64_field_column("field") + .with_time_column() +} diff --git a/iox_query/src/physical_optimizer/dedup/time_split.rs b/iox_query/src/physical_optimizer/dedup/time_split.rs new file mode 100644 index 0000000..29acccb --- /dev/null +++ b/iox_query/src/physical_optimizer/dedup/time_split.rs @@ -0,0 +1,235 @@ +use std::sync::Arc; + +use datafusion::{ + common::tree_node::{Transformed, TreeNode}, + config::ConfigOptions, + error::Result, + physical_optimizer::PhysicalOptimizerRule, + physical_plan::{union::UnionExec, ExecutionPlan}, +}; +use observability_deps::tracing::warn; + +use crate::{ + config::IoxConfigExt, + physical_optimizer::chunk_extraction::extract_chunks, + provider::{chunks_to_physical_nodes, group_potential_duplicates, DeduplicateExec}, +}; + +/// Split de-duplication operations based on time. +/// +/// Chunks that overlap will be part of the same de-dup group. 
+/// +/// This should usually be more cost-efficient. +#[derive(Debug, Default)] +pub struct TimeSplit; + +impl PhysicalOptimizerRule for TimeSplit { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + plan.transform_up(&|plan| { + let plan_any = plan.as_any(); + + if let Some(dedup_exec) = plan_any.downcast_ref::() { + let mut children = dedup_exec.children(); + assert_eq!(children.len(), 1); + let child = children.remove(0); + let Some((schema, chunks, output_sort_key)) = extract_chunks(child.as_ref()) else { + return Ok(Transformed::No(plan)); + }; + + let groups = group_potential_duplicates(chunks); + + // if there are no chunks or there is only one group, we don't need to split + if groups.len() < 2 { + return Ok(Transformed::No(plan)); + } + + // Protect against degenerative plans + let max_dedup_time_split = config + .extensions + .get::() + .cloned() + .unwrap_or_default() + .max_dedup_time_split; + if groups.len() > max_dedup_time_split { + warn!( + n_groups = groups.len(), + max_dedup_time_split, + "cannot split dedup operation based on time overlaps, too many groups" + ); + return Ok(Transformed::No(plan)); + } + + let out = UnionExec::new( + groups + .into_iter() + .map(|chunks| { + Arc::new(DeduplicateExec::new( + chunks_to_physical_nodes( + &schema, + output_sort_key.as_ref(), + chunks, + config.execution.target_partitions, + ), + dedup_exec.sort_keys().to_vec(), + dedup_exec.use_chunk_order_col(), + )) as _ + }) + .collect(), + ); + return Ok(Transformed::Yes(Arc::new(out))); + } + + Ok(Transformed::No(plan)) + }) + } + + fn name(&self) -> &str { + "time_split" + } + + fn schema_check(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use crate::{ + physical_optimizer::{ + dedup::test_util::{chunk, dedup_plan}, + test_util::OptimizationTest, + }, + QueryChunk, + }; + + use super::*; + + #[test] + fn test_no_chunks() { + let schema = chunk(1).schema().clone(); + let plan = dedup_plan(schema, vec![]); + let opt 
= TimeSplit; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " EmptyExec" + output: + Ok: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " EmptyExec" + "### + ); + } + + #[test] + fn test_all_overlap() { + let chunk1 = chunk(1).with_timestamp_min_max(5, 10); + let chunk2 = chunk(2).with_timestamp_min_max(3, 5); + let chunk3 = chunk(3) + .with_dummy_parquet_file() + .with_timestamp_min_max(8, 9); + let schema = chunk1.schema().clone(); + let plan = dedup_plan(schema, vec![chunk1, chunk2, chunk3]); + let opt = TimeSplit; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[field, tag1, tag2, time]" + - " ParquetExec: file_groups={1 group: [[3.parquet]]}, projection=[field, tag1, tag2, time]" + output: + Ok: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[field, tag1, tag2, time]" + - " ParquetExec: file_groups={1 group: [[3.parquet]]}, projection=[field, tag1, tag2, time]" + "### + ); + } + + #[test] + fn test_different_groups() { + let chunk1 = chunk(1).with_timestamp_min_max(0, 10); + let chunk2 = chunk(2).with_timestamp_min_max(11, 12); + // use at least 3 parquet files for one of the two partitions to validate that `target_partitions` is forwarded correctly + let chunk3 = chunk(3) + .with_dummy_parquet_file() + .with_timestamp_min_max(1, 5); + let chunk4 = chunk(4) + .with_dummy_parquet_file() + .with_timestamp_min_max(11, 11); + let chunk5 = chunk(5) + .with_dummy_parquet_file() + .with_timestamp_min_max(7, 8); + let chunk6 = chunk(6) + .with_dummy_parquet_file() + .with_timestamp_min_max(0, 0); + let schema = chunk1.schema().clone(); + let plan = dedup_plan(schema, vec![chunk1, chunk2, chunk3, chunk4, 
chunk5, chunk6]); + let opt = TimeSplit; + let mut config = ConfigOptions::default(); + config.execution.target_partitions = 2; + insta::assert_yaml_snapshot!( + OptimizationTest::new_with_config(plan, opt, &config), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[field, tag1, tag2, time]" + - " ParquetExec: file_groups={2 groups: [[3.parquet, 5.parquet], [4.parquet, 6.parquet]]}, projection=[field, tag1, tag2, time]" + output: + Ok: + - " UnionExec" + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time]" + - " ParquetExec: file_groups={2 groups: [[6.parquet, 5.parquet], [3.parquet]]}, projection=[field, tag1, tag2, time]" + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time]" + - " ParquetExec: file_groups={1 group: [[4.parquet]]}, projection=[field, tag1, tag2, time]" + "### + ); + } + + #[test] + fn test_max_split() { + let chunk1 = chunk(1).with_timestamp_min_max(1, 1); + let chunk2 = chunk(2).with_timestamp_min_max(2, 2); + let chunk3 = chunk(3).with_timestamp_min_max(3, 3); + let schema = chunk1.schema().clone(); + let plan = dedup_plan(schema, vec![chunk1, chunk2, chunk3]); + let opt = TimeSplit; + let mut config = ConfigOptions::default(); + config.extensions.insert(IoxConfigExt { + max_dedup_time_split: 2, + ..Default::default() + }); + insta::assert_yaml_snapshot!( + OptimizationTest::new_with_config(plan, opt, &config), + @r###" + --- + input: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=3, projection=[field, tag1, tag2, time]" + output: + Ok: + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=3, projection=[field, tag1, tag2, time]" + "### + ); + } +} 
diff --git a/iox_query/src/physical_optimizer/mod.rs b/iox_query/src/physical_optimizer/mod.rs new file mode 100644 index 0000000..a0bf7a4 --- /dev/null +++ b/iox_query/src/physical_optimizer/mod.rs @@ -0,0 +1,56 @@ +use std::sync::Arc; + +use datafusion::{execution::context::SessionState, physical_optimizer::PhysicalOptimizerRule}; + +use self::{ + combine_chunks::CombineChunks, + dedup::{ + dedup_null_columns::DedupNullColumns, dedup_sort_order::DedupSortOrder, + partition_split::PartitionSplit, remove_dedup::RemoveDedup, time_split::TimeSplit, + }, + predicate_pushdown::PredicatePushdown, + projection_pushdown::ProjectionPushdown, + sort::{order_union_sorted_inputs::OrderUnionSortedInputs, parquet_sortness::ParquetSortness}, + union::{nested_union::NestedUnion, one_union::OneUnion}, +}; + +mod chunk_extraction; +mod combine_chunks; +mod dedup; +mod predicate_pushdown; +mod projection_pushdown; +mod sort; +mod union; + +#[cfg(test)] +mod test_util; + +#[cfg(test)] +mod tests; + +/// Register IOx-specific [`PhysicalOptimizerRule`]s with the SessionContext +pub fn register_iox_physical_optimizers(state: SessionState) -> SessionState { + // prepend IOx-specific rules to DataFusion builtins + // The optimizer rules have to be done in this order + let mut optimizers: Vec> = vec![ + Arc::new(PartitionSplit), + Arc::new(TimeSplit), + Arc::new(RemoveDedup), + Arc::new(CombineChunks), + Arc::new(DedupNullColumns), + Arc::new(DedupSortOrder), + Arc::new(PredicatePushdown), + Arc::new(ProjectionPushdown), + Arc::new(ParquetSortness) as _, + Arc::new(NestedUnion), + Arc::new(OneUnion), + ]; + + // Append DataFUsion physical rules to the IOx-specific rules + optimizers.append(&mut state.physical_optimizers().to_vec()); + + // Add a rule to optimize plan with limit + optimizers.push(Arc::new(OrderUnionSortedInputs)); + + state.with_physical_optimizer_rules(optimizers) +} diff --git a/iox_query/src/physical_optimizer/predicate_pushdown.rs 
b/iox_query/src/physical_optimizer/predicate_pushdown.rs new file mode 100644 index 0000000..ab8ccd4 --- /dev/null +++ b/iox_query/src/physical_optimizer/predicate_pushdown.rs @@ -0,0 +1,496 @@ +use std::{collections::HashSet, sync::Arc}; + +use datafusion::{ + common::tree_node::{RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter}, + config::ConfigOptions, + datasource::physical_plan::ParquetExec, + error::{DataFusionError, Result}, + logical_expr::Operator, + physical_expr::{split_conjunction, utils::collect_columns}, + physical_optimizer::PhysicalOptimizerRule, + physical_plan::{ + empty::EmptyExec, + expressions::{BinaryExpr, Column}, + filter::FilterExec, + union::UnionExec, + ExecutionPlan, PhysicalExpr, + }, +}; + +use crate::provider::DeduplicateExec; + +/// Push down predicates. +#[derive(Debug, Default)] +pub struct PredicatePushdown; + +impl PhysicalOptimizerRule for PredicatePushdown { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + plan.transform_down(&|plan| { + let plan_any = plan.as_any(); + + if let Some(filter_exec) = plan_any.downcast_ref::() { + let mut children = filter_exec.children(); + assert_eq!(children.len(), 1); + let child = children.remove(0); + + let child_any = child.as_any(); + if child_any.downcast_ref::().is_some() { + return Ok(Transformed::Yes(child)); + } else if let Some(child_union) = child_any.downcast_ref::() { + let new_inputs = child_union + .inputs() + .iter() + .map(|input| { + FilterExec::try_new( + Arc::clone(filter_exec.predicate()), + Arc::clone(input), + ) + .map(|p| Arc::new(p) as Arc) + }) + .collect::>>()?; + let new_union = UnionExec::new(new_inputs); + return Ok(Transformed::Yes(Arc::new(new_union))); + } else if let Some(child_parquet) = child_any.downcast_ref::() { + let existing = child_parquet + .predicate() + .map(split_conjunction) + .unwrap_or_default(); + let both = conjunction( + existing + .into_iter() + .chain(split_conjunction(filter_exec.predicate())) + 
.cloned(), + ); + + let new_node = Arc::new(FilterExec::try_new( + Arc::clone(filter_exec.predicate()), + Arc::new(ParquetExec::new( + child_parquet.base_config().clone(), + both, + None, + )), + )?); + return Ok(Transformed::Yes(new_node)); + } else if let Some(child_dedup) = child_any.downcast_ref::() { + let dedup_cols = child_dedup.sort_columns(); + let (pushdown, no_pushdown): (Vec<_>, Vec<_>) = + split_conjunction(filter_exec.predicate()) + .into_iter() + .cloned() + .partition(|expr| { + collect_columns(expr) + .into_iter() + .all(|c| dedup_cols.contains(c.name())) + }); + + if !pushdown.is_empty() { + let mut grandchildren = child_dedup.children(); + assert_eq!(grandchildren.len(), 1); + let grandchild = grandchildren.remove(0); + + let mut new_node: Arc = Arc::new(DeduplicateExec::new( + Arc::new(FilterExec::try_new( + conjunction(pushdown).expect("not empty"), + grandchild, + )?), + child_dedup.sort_keys().to_vec(), + child_dedup.use_chunk_order_col(), + )); + if !no_pushdown.is_empty() { + new_node = Arc::new(FilterExec::try_new( + conjunction(no_pushdown).expect("not empty"), + new_node, + )?); + } + return Ok(Transformed::Yes(new_node)); + } + } + } + + Ok(Transformed::No(plan)) + }) + } + + fn name(&self) -> &str { + "predicate_pushdown" + } + + fn schema_check(&self) -> bool { + true + } +} + +#[derive(Debug, Default)] +struct ColumnCollector { + cols: HashSet, +} + +impl TreeNodeRewriter for ColumnCollector { + type N = Arc; + + fn pre_visit( + &mut self, + node: &Arc, + ) -> Result { + if let Some(column) = node.as_any().downcast_ref::() { + self.cols.insert(column.clone()); + } + Ok(RewriteRecursion::Continue) + } + + fn mutate( + &mut self, + expr: Arc, + ) -> Result, DataFusionError> { + Ok(expr) + } +} + +fn conjunction( + parts: impl IntoIterator>, +) -> Option> { + parts + .into_iter() + .reduce(|lhs, rhs| Arc::new(BinaryExpr::new(lhs, Operator::And, rhs))) +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::{DataType, Field, Schema, 
SchemaRef}; + use datafusion::{ + datasource::object_store::ObjectStoreUrl, + datasource::physical_plan::FileScanConfig, + logical_expr::Operator, + physical_expr::PhysicalSortExpr, + physical_plan::{ + expressions::{BinaryExpr, Column, Literal}, + placeholder_row::PlaceholderRowExec, + PhysicalExpr, Statistics, + }, + scalar::ScalarValue, + }; + use schema::sort::SortKeyBuilder; + + use crate::{physical_optimizer::test_util::OptimizationTest, util::arrow_sort_key_exprs}; + + use super::*; + + #[test] + fn test_empty_no_rows() { + let schema = schema(); + let plan = Arc::new( + FilterExec::try_new(predicate_tag(&schema), Arc::new(EmptyExec::new(schema))).unwrap(), + ); + let opt = PredicatePushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " FilterExec: tag1@0 = foo" + - " EmptyExec" + output: + Ok: + - " EmptyExec" + "### + ); + } + + #[test] + fn test_empty_with_rows() { + let schema = schema(); + let plan = Arc::new( + FilterExec::try_new( + predicate_tag(&schema), + Arc::new(PlaceholderRowExec::new(schema)), + ) + .unwrap(), + ); + let opt = PredicatePushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " FilterExec: tag1@0 = foo" + - " PlaceholderRowExec" + output: + Ok: + - " FilterExec: tag1@0 = foo" + - " PlaceholderRowExec" + "### + ); + } + + #[test] + fn test_union() { + let schema = schema(); + let plan = Arc::new( + FilterExec::try_new( + predicate_tag(&schema), + Arc::new(UnionExec::new( + (0..2) + .map(|_| Arc::new(PlaceholderRowExec::new(Arc::clone(&schema))) as _) + .collect(), + )), + ) + .unwrap(), + ); + let opt = PredicatePushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " FilterExec: tag1@0 = foo" + - " UnionExec" + - " PlaceholderRowExec" + - " PlaceholderRowExec" + output: + Ok: + - " UnionExec" + - " FilterExec: tag1@0 = foo" + - " PlaceholderRowExec" + - " FilterExec: tag1@0 = 
foo" + - " PlaceholderRowExec" + "### + ); + } + + #[test] + fn test_union_nested() { + let schema = schema(); + let plan = Arc::new( + FilterExec::try_new( + predicate_tag(&schema), + Arc::new(UnionExec::new(vec![Arc::new(UnionExec::new( + (0..2) + .map(|_| Arc::new(PlaceholderRowExec::new(Arc::clone(&schema))) as _) + .collect(), + ))])), + ) + .unwrap(), + ); + let opt = PredicatePushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " FilterExec: tag1@0 = foo" + - " UnionExec" + - " UnionExec" + - " PlaceholderRowExec" + - " PlaceholderRowExec" + output: + Ok: + - " UnionExec" + - " UnionExec" + - " FilterExec: tag1@0 = foo" + - " PlaceholderRowExec" + - " FilterExec: tag1@0 = foo" + - " PlaceholderRowExec" + "### + ); + } + + #[test] + fn test_parquet() { + let schema = schema(); + let base_config = FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test://").unwrap(), + file_schema: Arc::clone(&schema), + file_groups: vec![], + statistics: Statistics::new_unknown(&schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + }; + let plan = Arc::new( + FilterExec::try_new( + predicate_mixed(&schema), + Arc::new(ParquetExec::new( + base_config, + Some(predicate_tag(&schema)), + None, + )), + ) + .unwrap(), + ); + let opt = PredicatePushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " FilterExec: tag1@0 = field@2" + - " ParquetExec: file_groups={0 groups: []}, projection=[tag1, tag2, field], predicate=tag1@0 = foo, pruning_predicate=tag1_min@0 <= foo AND foo <= tag1_max@1" + output: + Ok: + - " FilterExec: tag1@0 = field@2" + - " ParquetExec: file_groups={0 groups: []}, projection=[tag1, tag2, field], predicate=tag1@0 = foo AND tag1@0 = field@2, pruning_predicate=tag1_min@0 <= foo AND foo <= tag1_max@1" + "### + ); + } + + #[test] + fn test_dedup_no_pushdown() { + let schema = schema(); + let plan = 
Arc::new( + FilterExec::try_new( + predicate_field(&schema), + Arc::new(DeduplicateExec::new( + Arc::new(PlaceholderRowExec::new(Arc::clone(&schema))), + sort_expr(&schema), + false, + )), + ) + .unwrap(), + ); + let opt = PredicatePushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " FilterExec: field@2 = val" + - " DeduplicateExec: [tag1@0 ASC,tag2@1 ASC]" + - " PlaceholderRowExec" + output: + Ok: + - " FilterExec: field@2 = val" + - " DeduplicateExec: [tag1@0 ASC,tag2@1 ASC]" + - " PlaceholderRowExec" + "### + ); + } + + #[test] + fn test_dedup_all_pushdown() { + let schema = schema(); + let plan = Arc::new( + FilterExec::try_new( + predicate_tag(&schema), + Arc::new(DeduplicateExec::new( + Arc::new(PlaceholderRowExec::new(Arc::clone(&schema))), + sort_expr(&schema), + false, + )), + ) + .unwrap(), + ); + let opt = PredicatePushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " FilterExec: tag1@0 = foo" + - " DeduplicateExec: [tag1@0 ASC,tag2@1 ASC]" + - " PlaceholderRowExec" + output: + Ok: + - " DeduplicateExec: [tag1@0 ASC,tag2@1 ASC]" + - " FilterExec: tag1@0 = foo" + - " PlaceholderRowExec" + "### + ); + } + + #[test] + fn test_dedup_mixed() { + let schema = schema(); + let plan = Arc::new( + FilterExec::try_new( + conjunction([ + predicate_tag(&schema), + predicate_tags(&schema), + predicate_field(&schema), + predicate_mixed(&schema), + predicate_other(), + ]) + .expect("not empty"), + Arc::new(DeduplicateExec::new( + Arc::new(PlaceholderRowExec::new(Arc::clone(&schema))), + sort_expr(&schema), + false, + )), + ) + .unwrap(), + ); + let opt = PredicatePushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " FilterExec: tag1@0 = foo AND tag1@0 = tag2@1 AND field@2 = val AND tag1@0 = field@2 AND true" + - " DeduplicateExec: [tag1@0 ASC,tag2@1 ASC]" + - " PlaceholderRowExec" + output: + Ok: + - " 
FilterExec: field@2 = val AND tag1@0 = field@2" + - " DeduplicateExec: [tag1@0 ASC,tag2@1 ASC]" + - " FilterExec: tag1@0 = foo AND tag1@0 = tag2@1 AND true" + - " PlaceholderRowExec" + "### + ); + } + + fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("tag1", DataType::Utf8, true), + Field::new("tag2", DataType::Utf8, true), + Field::new("field", DataType::UInt8, true), + ])) + } + + fn sort_expr(schema: &SchemaRef) -> Vec { + let sort_key = SortKeyBuilder::new() + .with_col("tag1") + .with_col("tag2") + .build(); + arrow_sort_key_exprs(&sort_key, schema) + } + + fn predicate_tag(schema: &SchemaRef) -> Arc { + Arc::new(BinaryExpr::new( + Arc::new(Column::new_with_schema("tag1", schema).unwrap()), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::from("foo"))), + )) + } + + fn predicate_tags(schema: &SchemaRef) -> Arc { + Arc::new(BinaryExpr::new( + Arc::new(Column::new_with_schema("tag1", schema).unwrap()), + Operator::Eq, + Arc::new(Column::new_with_schema("tag2", schema).unwrap()), + )) + } + + fn predicate_field(schema: &SchemaRef) -> Arc { + Arc::new(BinaryExpr::new( + Arc::new(Column::new_with_schema("field", schema).unwrap()), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::from("val"))), + )) + } + + fn predicate_mixed(schema: &SchemaRef) -> Arc { + Arc::new(BinaryExpr::new( + Arc::new(Column::new_with_schema("tag1", schema).unwrap()), + Operator::Eq, + Arc::new(Column::new_with_schema("field", schema).unwrap()), + )) + } + + fn predicate_other() -> Arc { + Arc::new(Literal::new(ScalarValue::from(true))) + } +} diff --git a/iox_query/src/physical_optimizer/projection_pushdown.rs b/iox_query/src/physical_optimizer/projection_pushdown.rs new file mode 100644 index 0000000..0efe597 --- /dev/null +++ b/iox_query/src/physical_optimizer/projection_pushdown.rs @@ -0,0 +1,1718 @@ +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + +use arrow::datatypes::SchemaRef; +use datafusion::{ + common::tree_node::{Transformed, 
TreeNode}, + config::ConfigOptions, + datasource::physical_plan::{FileScanConfig, ParquetExec}, + error::{DataFusionError, Result}, + physical_expr::{ + utils::{collect_columns, reassign_predicate_columns}, + PhysicalSortExpr, + }, + physical_optimizer::PhysicalOptimizerRule, + physical_plan::{ + empty::EmptyExec, + expressions::Column, + filter::FilterExec, + placeholder_row::PlaceholderRowExec, + projection::ProjectionExec, + sorts::{sort::SortExec, sort_preserving_merge::SortPreservingMergeExec}, + union::UnionExec, + ExecutionPlan, PhysicalExpr, + }, +}; + +use crate::provider::{DeduplicateExec, RecordBatchesExec}; + +/// Push down projections. +#[derive(Debug, Default)] +pub struct ProjectionPushdown; + +impl PhysicalOptimizerRule for ProjectionPushdown { + #[allow(clippy::only_used_in_recursion)] + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + plan.transform_down(&|plan| { + let plan_any = plan.as_any(); + + if let Some(projection_exec) = plan_any.downcast_ref::() { + let child = projection_exec.input(); + + let mut column_indices = Vec::with_capacity(projection_exec.expr().len()); + let mut column_names = Vec::with_capacity(projection_exec.expr().len()); + for (expr, output_name) in projection_exec.expr() { + if let Some(column) = expr.as_any().downcast_ref::() { + if column.name() == output_name { + column_indices.push(column.index()); + column_names.push(output_name.as_str()); + } else { + // don't bother w/ renames + return Ok(Transformed::No(plan)); + } + } else { + // don't bother to deal w/ calculation within projection nodes + return Ok(Transformed::No(plan)); + } + } + + let child_any = child.as_any(); + if let Some(child_empty) = child_any.downcast_ref::() { + let new_child = + EmptyExec::new(Arc::new(child_empty.schema().project(&column_indices)?)); + return Ok(Transformed::Yes(Arc::new(new_child))); + } else if let Some(child_placeholder) = + child_any.downcast_ref::() + { + let new_child = 
PlaceholderRowExec::new(Arc::new( + child_placeholder.schema().project(&column_indices)?, + )); + return Ok(Transformed::Yes(Arc::new(new_child))); + } else if let Some(child_union) = child_any.downcast_ref::() { + let new_inputs = child_union + .inputs() + .iter() + .map(|input| { + let exec = ProjectionExec::try_new( + projection_exec.expr().to_vec(), + Arc::clone(input), + )?; + Ok(Arc::new(exec) as _) + }) + .collect::>>()?; + let new_union = UnionExec::new(new_inputs); + return Ok(Transformed::Yes(Arc::new(new_union))); + } else if let Some(child_parquet) = child_any.downcast_ref::() { + let projection = match child_parquet.base_config().projection.as_ref() { + Some(projection) => column_indices + .into_iter() + .map(|idx| { + projection.get(idx).copied().ok_or_else(|| { + DataFusionError::Execution("Projection broken".to_string()) + }) + }) + .collect::>>()?, + None => column_indices, + }; + let output_ordering = child_parquet + .base_config() + .output_ordering + .iter() + .map(|output_ordering| { + project_output_ordering(output_ordering, projection_exec.schema()) + }) + .collect::>()?; + let base_config = FileScanConfig { + projection: Some(projection), + output_ordering, + ..child_parquet.base_config().clone() + }; + let new_child = + ParquetExec::new(base_config, child_parquet.predicate().cloned(), None); + return Ok(Transformed::Yes(Arc::new(new_child))); + } else if let Some(child_filter) = child_any.downcast_ref::() { + let filter_required_cols = collect_columns(child_filter.predicate()); + let filter_required_cols = filter_required_cols + .iter() + .map(|col| col.name()) + .collect::>(); + + let plan = wrap_user_into_projections( + &filter_required_cols, + &column_names, + Arc::clone(child_filter.input()), + |plan| { + Ok(Arc::new(FilterExec::try_new( + reassign_predicate_columns( + Arc::clone(child_filter.predicate()), + &plan.schema(), + false, + )?, + plan, + )?)) + }, + )?; + + return Ok(Transformed::Yes(plan)); + } else if let Some(child_sort) = 
child_any.downcast_ref::() { + let sort_required_cols = child_sort + .expr() + .iter() + .map(|expr| collect_columns(&expr.expr)) + .collect::>(); + let sort_required_cols = sort_required_cols + .iter() + .flat_map(|cols| cols.iter()) + .map(|col| col.name()) + .collect::>(); + + let plan = wrap_user_into_projections( + &sort_required_cols, + &column_names, + Arc::clone(child_sort.input()), + |plan| { + Ok(Arc::new( + SortExec::new( + reassign_sort_exprs_columns(child_sort.expr(), &plan.schema())?, + plan, + ) + .with_preserve_partitioning(child_sort.preserve_partitioning()) + .with_fetch(child_sort.fetch()), + )) + }, + )?; + + return Ok(Transformed::Yes(plan)); + } else if let Some(child_sort) = child_any.downcast_ref::() + { + let sort_required_cols = child_sort + .expr() + .iter() + .map(|expr| collect_columns(&expr.expr)) + .collect::>(); + let sort_required_cols = sort_required_cols + .iter() + .flat_map(|cols| cols.iter()) + .map(|col| col.name()) + .collect::>(); + + let plan = wrap_user_into_projections( + &sort_required_cols, + &column_names, + Arc::clone(child_sort.input()), + |plan| { + Ok(Arc::new(SortPreservingMergeExec::new( + reassign_sort_exprs_columns(child_sort.expr(), &plan.schema())?, + plan, + ))) + }, + )?; + + return Ok(Transformed::Yes(plan)); + } else if let Some(child_proj) = child_any.downcast_ref::() { + let expr = column_indices + .iter() + .map(|idx| child_proj.expr()[*idx].clone()) + .collect(); + let plan = Arc::new(ProjectionExec::try_new( + expr, + Arc::clone(child_proj.input()), + )?); + + // need to call `optimize` directly on the plan, because otherwise we would continue with the child + // and miss the optimization of that particular new ProjectionExec + let plan = self.optimize(plan, config)?; + + return Ok(Transformed::Yes(plan)); + } else if let Some(child_dedup) = child_any.downcast_ref::() { + let dedup_required_cols = child_dedup.sort_columns(); + + let mut children = child_dedup.children(); + assert_eq!(children.len(), 
1); + let input = children.pop().expect("just checked len"); + + let plan = wrap_user_into_projections( + &dedup_required_cols, + &column_names, + input, + |plan| { + let sort_keys = reassign_sort_exprs_columns( + child_dedup.sort_keys(), + &plan.schema(), + )?; + Ok(Arc::new(DeduplicateExec::new( + plan, + sort_keys, + child_dedup.use_chunk_order_col(), + ))) + }, + )?; + + return Ok(Transformed::Yes(plan)); + } else if let Some(child_recordbatches) = + child_any.downcast_ref::() + { + let new_child = RecordBatchesExec::new( + child_recordbatches.chunks().cloned(), + Arc::new(child_recordbatches.schema().project(&column_indices)?), + child_recordbatches.output_sort_key_memo().cloned(), + ); + return Ok(Transformed::Yes(Arc::new(new_child))); + } + } + + Ok(Transformed::No(plan)) + }) + } + + fn name(&self) -> &str { + "projection_pushdown" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// Given the output ordering and a projected schema, returns the +/// largest prefix of the ordering that is in the projection +/// +/// For example, +/// +/// ```text +/// output_ordering: a, b, c +/// projection: a, c +/// returns --> a +/// ``` +/// +/// To see why the input has to be a prefix, consider this input: +/// +/// ```text +/// a b +/// 1 1 +/// 2 2 +/// 3 1 +/// `` +/// +/// It is sorted on `a,b` but *not* sorted on `b` +fn project_output_ordering( + output_ordering: &[PhysicalSortExpr], + projected_schema: SchemaRef, +) -> Result> { + // filter out sort exprs columns that got projected away + let known_columns = projected_schema + .all_fields() + .iter() + .map(|f| f.name().as_str()) + .collect::>(); + + // take longest prefix + let sort_exprs = output_ordering + .iter() + .take_while(|expr| { + if let Some(col) = expr.expr.as_any().downcast_ref::() { + known_columns.contains(col.name()) + } else { + // do not keep exprs like `a+1` or `-a` as they may + // not maintain ordering + false + } + }) + .cloned() + .collect::>(); + + 
reassign_sort_exprs_columns(&sort_exprs, &projected_schema) +} + +fn schema_name_projection( + schema: &SchemaRef, + cols: &[&str], +) -> Result, String)>> { + let idx_lookup = schema + .fields() + .iter() + .enumerate() + .map(|(idx, field)| (field.name().as_str(), idx)) + .collect::>(); + + cols.iter() + .map(|col| { + let idx = *idx_lookup.get(col).ok_or_else(|| { + DataFusionError::Execution(format!("Cannot find column to project: {col}")) + })?; + + let expr = Arc::new(Column::new(col, idx)) as _; + Ok((expr, (*col).to_owned())) + }) + .collect::>>() +} + +/// Wraps an intermediate node (like [`FilterExec`]) that has a single input but also uses some columns itself into +/// appropriate projections. +/// +/// This will turn: +/// +/// ```yaml +/// --- +/// projection: +/// user: # e.g. FilterExec +/// inner: +/// ``` +/// +/// into +/// +/// ```yaml +/// --- +/// projection: # if `user` outputs too many cols +/// user: +/// projection: # if `inner` outputs too many cols +/// inner: +/// ``` +fn wrap_user_into_projections( + user_required_cols: &HashSet<&str>, + outer_cols: &[&str], + inner_plan: Arc, + user_constructor: F, +) -> Result> +where + F: FnOnce(Arc) -> Result>, +{ + let mut plan = inner_plan; + + let inner_required_cols = user_required_cols + .iter() + .chain(outer_cols.iter()) + .copied() + .collect::>(); + + // sort inner required cols according the final projection + let outer_cols_order = outer_cols + .iter() + .copied() + .enumerate() + .map(|(idx, col)| (col, idx)) + .collect::>(); + let mut inner_projection_cols = inner_required_cols + .iter() + .copied() + .map(|col| { + // Note: if the col is NOT known, this will fail in `schema_name_projection`, so we just default it here + let idx = outer_cols_order.get(col).copied().unwrap_or_default(); + (idx, col) + }) + .collect::>(); + inner_projection_cols.sort(); + let inner_projection_cols = inner_projection_cols + .into_iter() + .map(|(_idx, col)| col) + .collect::>(); + + let plan_schema = 
plan.schema(); + let plan_cols = plan_schema + .fields() + .iter() + .map(|f| f.name().as_str()) + .collect::>(); + if plan_cols != inner_projection_cols { + let expr = schema_name_projection(&plan.schema(), &inner_projection_cols)?; + plan = Arc::new(ProjectionExec::try_new(expr, plan)?); + } + + plan = user_constructor(plan)?; + + if outer_cols.len() < plan.schema().fields().len() { + let expr = schema_name_projection(&plan.schema(), outer_cols)?; + plan = Arc::new(ProjectionExec::try_new(expr, plan)?); + } + + Ok(plan) +} + +fn reassign_sort_exprs_columns( + sort_exprs: &[PhysicalSortExpr], + schema: &SchemaRef, +) -> Result> { + sort_exprs + .iter() + .map(|expr| { + Ok(PhysicalSortExpr { + expr: reassign_predicate_columns(Arc::clone(&expr.expr), schema, false)?, + options: expr.options, + }) + }) + .collect() +} + +#[cfg(test)] +mod tests { + use arrow::{ + compute::SortOptions, + datatypes::{DataType, Field, Fields, Schema, SchemaRef}, + }; + use datafusion::{ + datasource::object_store::ObjectStoreUrl, + logical_expr::Operator, + physical_plan::{ + expressions::{BinaryExpr, Literal}, + DisplayAs, PhysicalExpr, Statistics, + }, + scalar::ScalarValue, + }; + use serde::Serialize; + + use crate::{ + physical_optimizer::test_util::{assert_unknown_partitioning, OptimizationTest}, + test::TestChunk, + }; + + use super::*; + + #[test] + fn test_empty_pushdown_select() { + let schema = schema(); + let plan = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("tag1", &schema), String::from("tag1"))], + Arc::new(EmptyExec::new(schema)), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + let test = OptimizationTest::new(plan, opt); + insta::assert_yaml_snapshot!( + test, + @r###" + --- + input: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " EmptyExec" + output: + Ok: + - " EmptyExec" + "### + ); + + let empty_exec = test + .output_plan() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let expected_schema = Schema::new(vec![Field::new("tag1", 
DataType::Utf8, true)]); + assert_eq!(empty_exec.schema().as_ref(), &expected_schema); + } + + #[test] + fn test_empty_pushdown_reorder() { + let schema = schema(); + let plan = Arc::new( + ProjectionExec::try_new( + vec![ + (expr_col("tag2", &schema), String::from("tag2")), + (expr_col("tag1", &schema), String::from("tag1")), + (expr_col("field", &schema), String::from("field")), + ], + Arc::new(EmptyExec::new(schema)), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + let test = OptimizationTest::new(plan, opt); + insta::assert_yaml_snapshot!( + test, + @r###" + --- + input: + - " ProjectionExec: expr=[tag2@1 as tag2, tag1@0 as tag1, field@2 as field]" + - " EmptyExec" + output: + Ok: + - " EmptyExec" + "### + ); + + let empty_exec = test + .output_plan() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let expected_schema = Schema::new(vec![ + Field::new("tag2", DataType::Utf8, true), + Field::new("tag1", DataType::Utf8, true), + Field::new("field", DataType::UInt64, true), + ]); + assert_eq!(empty_exec.schema().as_ref(), &expected_schema); + } + + #[test] + fn test_ignore_when_only_impure_projection_rename() { + let schema = schema(); + let plan = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("tag2", &schema), String::from("tag1"))], + Arc::new(EmptyExec::new(schema)), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " ProjectionExec: expr=[tag2@1 as tag1]" + - " EmptyExec" + output: + Ok: + - " ProjectionExec: expr=[tag2@1 as tag1]" + - " EmptyExec" + "### + ); + } + + #[test] + fn test_ignore_when_partial_impure_projection_rename() { + let schema = schema(); + let plan = Arc::new( + ProjectionExec::try_new( + vec![ + (expr_col("tag1", &schema), String::from("tag1")), + (expr_col("tag2", &schema), String::from("tag3")), + ], + Arc::new(EmptyExec::new(schema)), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + 
insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " ProjectionExec: expr=[tag1@0 as tag1, tag2@1 as tag3]" + - " EmptyExec" + output: + Ok: + - " ProjectionExec: expr=[tag1@0 as tag1, tag2@1 as tag3]" + - " EmptyExec" + "### + ); + } + + #[test] + fn test_ignore_impure_projection_calc() { + let schema = schema(); + let plan = Arc::new( + ProjectionExec::try_new( + vec![( + Arc::new(Literal::new(ScalarValue::from("foo"))), + String::from("tag1"), + )], + Arc::new(EmptyExec::new(schema)), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " ProjectionExec: expr=[foo as tag1]" + - " EmptyExec" + output: + Ok: + - " ProjectionExec: expr=[foo as tag1]" + - " EmptyExec" + "### + ); + } + + #[test] + fn test_unknown_node_type() { + let schema = schema(); + let plan = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("tag1", &schema), String::from("tag1"))], + Arc::new(TestExec::new(schema)), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " Test" + output: + Ok: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " Test" + "### + ); + } + + #[test] + fn test_union() { + let schema = schema(); + let plan = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("tag1", &schema), String::from("tag1"))], + Arc::new(UnionExec::new(vec![ + Arc::new(TestExec::new(Arc::clone(&schema))), + Arc::new(TestExec::new(schema)), + ])), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " UnionExec" + - " Test" + - " Test" + output: + Ok: + - " UnionExec" + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " Test" + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " 
Test" + "### + ); + } + + #[test] + fn test_nested_union() { + let schema = schema(); + let plan = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("tag1", &schema), String::from("tag1"))], + Arc::new(UnionExec::new(vec![ + Arc::new(UnionExec::new(vec![ + Arc::new(TestExec::new(Arc::clone(&schema))), + Arc::new(TestExec::new(Arc::clone(&schema))), + ])), + Arc::new(TestExec::new(schema)), + ])), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " UnionExec" + - " UnionExec" + - " Test" + - " Test" + - " Test" + output: + Ok: + - " UnionExec" + - " UnionExec" + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " Test" + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " Test" + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " Test" + "### + ); + } + + #[test] + fn test_parquet() { + let schema = Arc::new(Schema::new(vec![ + Field::new("tag1", DataType::Utf8, true), + Field::new("tag2", DataType::Utf8, true), + Field::new("tag3", DataType::Utf8, true), + Field::new("field", DataType::UInt64, true), + ])); + let projection = vec![3, 2, 1]; + let schema_projected = Arc::new(schema.project(&projection).unwrap()); + let base_config = FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test://").unwrap(), + file_schema: Arc::clone(&schema), + file_groups: vec![], + statistics: Statistics::new_unknown(&schema), + projection: Some(projection), + limit: None, + table_partition_cols: vec![], + output_ordering: vec![vec![ + PhysicalSortExpr { + expr: expr_col("tag3", &schema_projected), + options: Default::default(), + }, + PhysicalSortExpr { + expr: expr_col("field", &schema_projected), + options: Default::default(), + }, + PhysicalSortExpr { + expr: expr_col("tag2", &schema_projected), + options: Default::default(), + }, + ]], + }; + let inner = ParquetExec::new(base_config, Some(expr_string_cmp("tag1", &schema)), None); + 
let plan = Arc::new( + ProjectionExec::try_new( + vec![ + (expr_col("tag2", &inner.schema()), String::from("tag2")), + (expr_col("tag3", &inner.schema()), String::from("tag3")), + ], + Arc::new(inner), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + let test = OptimizationTest::new(plan, opt); + insta::assert_yaml_snapshot!( + test, + @r###" + --- + input: + - " ProjectionExec: expr=[tag2@2 as tag2, tag3@1 as tag3]" + - " ParquetExec: file_groups={0 groups: []}, projection=[field, tag3, tag2], output_ordering=[tag3@1 ASC, field@0 ASC, tag2@2 ASC], predicate=tag1@0 = foo, pruning_predicate=tag1_min@0 <= foo AND foo <= tag1_max@1" + output: + Ok: + - " ParquetExec: file_groups={0 groups: []}, projection=[tag2, tag3], output_ordering=[tag3@1 ASC], predicate=tag1@0 = foo, pruning_predicate=tag1_min@0 <= foo AND foo <= tag1_max@1" + "### + ); + + let parquet_exec = test + .output_plan() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let expected_schema = Schema::new(vec![ + Field::new("tag2", DataType::Utf8, true), + Field::new("tag3", DataType::Utf8, true), + ]); + assert_eq!(parquet_exec.schema().as_ref(), &expected_schema); + } + + #[test] + fn test_filter_projection_split() { + let schema = schema(); + let plan = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("tag1", &schema), String::from("tag1"))], + Arc::new( + FilterExec::try_new( + expr_string_cmp("tag2", &schema), + Arc::new(TestExec::new(schema)), + ) + .unwrap(), + ), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " FilterExec: tag2@1 = foo" + - " Test" + output: + Ok: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " FilterExec: tag2@1 = foo" + - " ProjectionExec: expr=[tag1@0 as tag1, tag2@1 as tag2]" + - " Test" + "### + ); + } + + #[test] + fn test_filter_inner_does_not_need_projection() { + let schema = schema(); + let inner = 
TestExec::new(Arc::new(schema.project(&[0, 1]).unwrap())); + let plan = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("tag1", &inner.schema()), String::from("tag1"))], + Arc::new( + FilterExec::try_new(expr_string_cmp("tag2", &inner.schema()), Arc::new(inner)) + .unwrap(), + ), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " FilterExec: tag2@1 = foo" + - " Test" + output: + Ok: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " FilterExec: tag2@1 = foo" + - " Test" + "### + ); + } + + #[test] + fn test_filter_outer_does_not_need_projection() { + let schema = schema(); + let plan = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("tag2", &schema), String::from("tag2"))], + Arc::new( + FilterExec::try_new( + expr_string_cmp("tag2", &schema), + Arc::new(TestExec::new(schema)), + ) + .unwrap(), + ), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " ProjectionExec: expr=[tag2@1 as tag2]" + - " FilterExec: tag2@1 = foo" + - " Test" + output: + Ok: + - " FilterExec: tag2@0 = foo" + - " ProjectionExec: expr=[tag2@1 as tag2]" + - " Test" + "### + ); + } + + #[test] + fn test_filter_all_projections_unnecessary() { + let schema = schema(); + let inner = TestExec::new(Arc::new(schema.project(&[1]).unwrap())); + let plan = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("tag2", &inner.schema()), String::from("tag2"))], + Arc::new( + FilterExec::try_new(expr_string_cmp("tag2", &inner.schema()), Arc::new(inner)) + .unwrap(), + ), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " ProjectionExec: expr=[tag2@0 as tag2]" + - " FilterExec: tag2@0 = foo" + - " Test" + output: + Ok: + - " FilterExec: tag2@0 = foo" + - " 
Test" + "### + ); + } + + #[test] + fn test_filter_uses_resorted_cols() { + let schema = schema(); + let plan = Arc::new( + ProjectionExec::try_new( + vec![ + (expr_col("tag2", &schema), String::from("tag2")), + (expr_col("tag1", &schema), String::from("tag1")), + (expr_col("field", &schema), String::from("field")), + ], + Arc::new( + FilterExec::try_new( + expr_and( + expr_string_cmp("tag2", &schema), + expr_string_cmp("tag1", &schema), + ), + Arc::new(TestExec::new(schema)), + ) + .unwrap(), + ), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " ProjectionExec: expr=[tag2@1 as tag2, tag1@0 as tag1, field@2 as field]" + - " FilterExec: tag2@1 = foo AND tag1@0 = foo" + - " Test" + output: + Ok: + - " FilterExec: tag2@0 = foo AND tag1@1 = foo" + - " ProjectionExec: expr=[tag2@1 as tag2, tag1@0 as tag1, field@2 as field]" + - " Test" + "### + ); + } + + // since `SortExec` and `FilterExec` both use `wrap_user_into_projections`, we only test a few variants for `SortExec` + #[test] + fn test_sort_projection_split() { + let schema = schema(); + let plan = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("tag1", &schema), String::from("tag1"))], + Arc::new( + SortExec::new( + vec![PhysicalSortExpr { + expr: expr_col("tag2", &schema), + options: SortOptions { + descending: true, + ..Default::default() + }, + }], + Arc::new(TestExec::new(schema)), + ) + .with_fetch(Some(42)), + ), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " SortExec: TopK(fetch=42), expr=[tag2@1 DESC]" + - " Test" + output: + Ok: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " SortExec: TopK(fetch=42), expr=[tag2@1 DESC]" + - " ProjectionExec: expr=[tag1@0 as tag1, tag2@1 as tag2]" + - " Test" + "### + ); + } + + #[test] + fn 
test_sort_preserve_partitioning() { + let schema = schema(); + let plan = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("tag1", &schema), String::from("tag1"))], + Arc::new( + SortExec::new( + vec![PhysicalSortExpr { + expr: expr_col("tag2", &schema), + options: SortOptions { + descending: true, + ..Default::default() + }, + }], + Arc::new(TestExec::new_with_partitions(schema, 2)), + ) + .with_preserve_partitioning(true) + .with_fetch(Some(42)), + ), + ) + .unwrap(), + ); + + assert_unknown_partitioning(plan.output_partitioning(), 2); + + let opt = ProjectionPushdown; + let test = OptimizationTest::new(plan, opt); + insta::assert_yaml_snapshot!( + test, + @r###" + --- + input: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " SortExec: TopK(fetch=42), expr=[tag2@1 DESC]" + - " Test" + output: + Ok: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " SortExec: TopK(fetch=42), expr=[tag2@1 DESC]" + - " ProjectionExec: expr=[tag1@0 as tag1, tag2@1 as tag2]" + - " Test" + "### + ); + + assert_unknown_partitioning(test.output_plan().unwrap().output_partitioning(), 2); + } + + // since `SortPreservingMergeExec` and `FilterExec` both use `wrap_user_into_projections`, we only test one variant for `SortPreservingMergeExec` + #[test] + fn test_sortpreservingmerge_projection_split() { + let schema = schema(); + let plan = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("tag1", &schema), String::from("tag1"))], + Arc::new(SortPreservingMergeExec::new( + vec![PhysicalSortExpr { + expr: expr_col("tag2", &schema), + options: SortOptions { + descending: true, + ..Default::default() + }, + }], + Arc::new(TestExec::new(schema)), + )), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " SortPreservingMergeExec: [tag2@1 DESC]" + - " Test" + output: + Ok: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " SortPreservingMergeExec: 
[tag2@1 DESC]" + - " ProjectionExec: expr=[tag1@0 as tag1, tag2@1 as tag2]" + - " Test" + "### + ); + } + + #[test] + fn test_nested_proj_inner_is_impure() { + let schema = schema(); + let plan = Arc::new(EmptyExec::new(schema)); + let plan = Arc::new( + ProjectionExec::try_new( + vec![ + ( + Arc::new(Literal::new(ScalarValue::from("foo"))), + String::from("tag1"), + ), + ( + Arc::new(Literal::new(ScalarValue::from("bar"))), + String::from("tag2"), + ), + ], + plan, + ) + .unwrap(), + ); + let plan = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("tag1", &plan.schema()), String::from("tag1"))], + plan, + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " ProjectionExec: expr=[foo as tag1, bar as tag2]" + - " EmptyExec" + output: + Ok: + - " ProjectionExec: expr=[foo as tag1]" + - " EmptyExec" + "### + ); + } + + #[test] + fn test_nested_proj_inner_is_pure() { + let schema = schema(); + let plan = Arc::new(EmptyExec::new(schema)); + let plan = Arc::new( + ProjectionExec::try_new( + vec![ + (expr_col("tag1", &plan.schema()), String::from("tag1")), + (expr_col("tag2", &plan.schema()), String::from("tag2")), + ], + plan, + ) + .unwrap(), + ); + let plan = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("tag1", &plan.schema()), String::from("tag1"))], + plan, + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + let test = OptimizationTest::new(plan, opt); + insta::assert_yaml_snapshot!( + test, + @r###" + --- + input: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " ProjectionExec: expr=[tag1@0 as tag1, tag2@1 as tag2]" + - " EmptyExec" + output: + Ok: + - " EmptyExec" + "### + ); + let empty_exec = test + .output_plan() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let expected_schema = Schema::new(vec![Field::new("tag1", DataType::Utf8, true)]); + assert_eq!(empty_exec.schema().as_ref(), 
&expected_schema); + } + + // since `DeduplicateExec` and `FilterExec` both use `wrap_user_into_projections`, we only test a few variants for `DeduplicateExec` + #[test] + fn test_dedup_projection_split1() { + let schema = schema(); + let plan = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("tag1", &schema), String::from("tag1"))], + Arc::new(DeduplicateExec::new( + Arc::new(TestExec::new(Arc::clone(&schema))), + vec![PhysicalSortExpr { + expr: expr_col("tag2", &schema), + options: SortOptions { + descending: true, + ..Default::default() + }, + }], + false, + )), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " DeduplicateExec: [tag2@1 DESC]" + - " Test" + output: + Ok: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " DeduplicateExec: [tag2@1 DESC]" + - " ProjectionExec: expr=[tag1@0 as tag1, tag2@1 as tag2]" + - " Test" + "### + ); + } + + #[test] + fn test_dedup_projection_split2() { + let schema = Arc::new(Schema::new(vec![ + Field::new("tag1", DataType::Utf8, true), + Field::new("tag2", DataType::Utf8, true), + Field::new("field1", DataType::UInt64, true), + Field::new("field2", DataType::UInt64, true), + ])); + let plan = Arc::new( + ProjectionExec::try_new( + vec![ + (expr_col("tag1", &schema), String::from("tag1")), + (expr_col("field1", &schema), String::from("field1")), + ], + Arc::new(DeduplicateExec::new( + Arc::new(TestExec::new(Arc::clone(&schema))), + vec![ + PhysicalSortExpr { + expr: expr_col("tag1", &schema), + options: SortOptions { + descending: true, + ..Default::default() + }, + }, + PhysicalSortExpr { + expr: expr_col("tag2", &schema), + options: SortOptions { + descending: false, + ..Default::default() + }, + }, + ], + false, + )), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " 
ProjectionExec: expr=[tag1@0 as tag1, field1@2 as field1]" + - " DeduplicateExec: [tag1@0 DESC,tag2@1 ASC]" + - " Test" + output: + Ok: + - " ProjectionExec: expr=[tag1@0 as tag1, field1@2 as field1]" + - " DeduplicateExec: [tag1@0 DESC,tag2@1 ASC]" + - " ProjectionExec: expr=[tag1@0 as tag1, tag2@1 as tag2, field1@2 as field1]" + - " Test" + "### + ); + } + + #[test] + fn test_recordbatches() { + let schema = schema(); + let chunk = TestChunk::new("table") + .with_tag_column("tag1") + .with_u64_column("field"); + let plan = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("tag1", &schema), String::from("tag1"))], + Arc::new(RecordBatchesExec::new( + vec![Arc::new(chunk) as _], + schema, + None, + )), + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + let test = OptimizationTest::new(plan, opt); + insta::assert_yaml_snapshot!( + test, + @r###" + --- + input: + - " ProjectionExec: expr=[tag1@0 as tag1]" + - " RecordBatchesExec: chunks=1, projection=[tag1, tag2, field]" + output: + Ok: + - " RecordBatchesExec: chunks=1, projection=[tag1]" + "### + ); + + let recordbatches_exec = test + .output_plan() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let expected_schema = Schema::new(vec![Field::new("tag1", DataType::Utf8, true)]); + assert_eq!(recordbatches_exec.schema().as_ref(), &expected_schema); + } + + #[test] + fn test_integration() { + let schema = Arc::new(Schema::new(vec![ + Field::new("tag1", DataType::Utf8, true), + Field::new("tag2", DataType::Utf8, true), + Field::new("field1", DataType::UInt64, true), + Field::new("field2", DataType::UInt64, true), + ])); + let base_config = FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test://").unwrap(), + file_schema: Arc::clone(&schema), + file_groups: vec![], + statistics: Statistics::new_unknown(&schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + }; + let plan = Arc::new(ParquetExec::new(base_config, None, None)); + let plan = 
Arc::new(UnionExec::new(vec![plan])); + let plan_schema = plan.schema(); + let plan = Arc::new(DeduplicateExec::new( + plan, + vec![ + PhysicalSortExpr { + expr: expr_col("tag1", &plan_schema), + options: Default::default(), + }, + PhysicalSortExpr { + expr: expr_col("tag2", &plan_schema), + options: Default::default(), + }, + ], + false, + )); + let plan = + Arc::new(FilterExec::try_new(expr_string_cmp("tag2", &plan.schema()), plan).unwrap()); + let plan = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("field1", &plan.schema()), String::from("field1"))], + plan, + ) + .unwrap(), + ); + let opt = ProjectionPushdown; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " ProjectionExec: expr=[field1@2 as field1]" + - " FilterExec: tag2@1 = foo" + - " DeduplicateExec: [tag1@0 ASC,tag2@1 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={0 groups: []}, projection=[tag1, tag2, field1, field2]" + output: + Ok: + - " ProjectionExec: expr=[field1@0 as field1]" + - " FilterExec: tag2@1 = foo" + - " ProjectionExec: expr=[field1@0 as field1, tag2@2 as tag2]" + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC]" + - " UnionExec" + - " ParquetExec: file_groups={0 groups: []}, projection=[field1, tag1, tag2]" + "### + ); + } + + #[test] + fn test_project_output_ordering_keep() { + let schema = schema(); + let projection = vec!["tag1", "tag2"]; + let output_ordering = vec![ + PhysicalSortExpr { + expr: expr_col("tag1", &schema), + options: Default::default(), + }, + PhysicalSortExpr { + expr: expr_col("tag2", &schema), + options: Default::default(), + }, + ]; + + insta::assert_yaml_snapshot!( + ProjectOutputOrdering::new(&schema, output_ordering, projection), + @r###" + --- + output_ordering: + - tag1@0 + - tag2@1 + projection: + - tag1 + - tag2 + projected_ordering: + - tag1@0 + - tag2@1 + "### + ); + } + + #[test] + fn test_project_output_ordering_project_prefix() { + let schema = schema(); + let projection = vec!["tag1"]; // 
prefix of the sort key + let output_ordering = vec![ + PhysicalSortExpr { + expr: expr_col("tag1", &schema), + options: Default::default(), + }, + PhysicalSortExpr { + expr: expr_col("tag2", &schema), + options: Default::default(), + }, + ]; + + insta::assert_yaml_snapshot!( + ProjectOutputOrdering::new(&schema, output_ordering, projection), + @r###" + --- + output_ordering: + - tag1@0 + - tag2@1 + projection: + - tag1 + projected_ordering: + - tag1@0 + "### + ); + } + + #[test] + fn test_project_output_ordering_project_non_prefix() { + let schema = schema(); + let projection = vec!["tag2"]; // in sort key, but not prefix + let output_ordering = vec![ + PhysicalSortExpr { + expr: expr_col("tag1", &schema), + options: Default::default(), + }, + PhysicalSortExpr { + expr: expr_col("tag2", &schema), + options: Default::default(), + }, + ]; + + insta::assert_yaml_snapshot!( + ProjectOutputOrdering::new(&schema, output_ordering, projection), + @r###" + --- + output_ordering: + - tag1@0 + - tag2@1 + projection: + - tag2 + projected_ordering: [] + "### + ); + } + + #[test] + fn test_project_output_ordering_projection_reorder() { + let schema = schema(); + let projection = vec!["tag2", "tag1", "field"]; // in different order than sort key + let output_ordering = vec![ + PhysicalSortExpr { + expr: expr_col("tag1", &schema), + options: Default::default(), + }, + PhysicalSortExpr { + expr: expr_col("tag2", &schema), + options: Default::default(), + }, + ]; + + insta::assert_yaml_snapshot!( + ProjectOutputOrdering::new(&schema, output_ordering, projection), + @r###" + --- + output_ordering: + - tag1@0 + - tag2@1 + projection: + - tag2 + - tag1 + - field + projected_ordering: + - tag1@1 + - tag2@0 + "### + ); + } + + #[test] + fn test_project_output_ordering_constant() { + let schema = schema(); + let projection = vec!["tag2"]; + let output_ordering = vec![ + // ordering by a constant is ignored + PhysicalSortExpr { + expr: datafusion::physical_plan::expressions::lit(1), + 
options: Default::default(), + }, + PhysicalSortExpr { + expr: expr_col("tag2", &schema), + options: Default::default(), + }, + ]; + + insta::assert_yaml_snapshot!( + ProjectOutputOrdering::new(&schema, output_ordering, projection), + @r###" + --- + output_ordering: + - "1" + - tag2@1 + projection: + - tag2 + projected_ordering: [] + "### + ); + } + + #[test] + fn test_project_output_ordering_constant_second_position() { + let schema = schema(); + let projection = vec!["tag2"]; + let output_ordering = vec![ + PhysicalSortExpr { + expr: expr_col("tag2", &schema), + options: Default::default(), + }, + // ordering by a constant is ignored + PhysicalSortExpr { + expr: datafusion::physical_plan::expressions::lit(1), + options: Default::default(), + }, + ]; + + insta::assert_yaml_snapshot!( + ProjectOutputOrdering::new(&schema, output_ordering, projection), + @r###" + --- + output_ordering: + - tag2@1 + - "1" + projection: + - tag2 + projected_ordering: + - tag2@0 + "### + ); + } + + /// project the output_ordering with the projection, + // derive serde to make a nice 'insta' snapshot + #[derive(Debug, Serialize)] + struct ProjectOutputOrdering { + output_ordering: Vec, + projection: Vec, + projected_ordering: Vec, + } + + impl ProjectOutputOrdering { + fn new( + schema: &Schema, + output_ordering: Vec, + projection: Vec<&'static str>, + ) -> Self { + let projected_fields: Fields = projection + .iter() + .map(|field_name| { + schema + .field_with_name(field_name) + .expect("finding field") + .clone() + }) + .collect(); + let projected_schema = Arc::new(Schema::new(projected_fields)); + + let projected_ordering = project_output_ordering(&output_ordering, projected_schema); + + let projected_ordering = match projected_ordering { + Ok(projected_ordering) => format_sort_exprs(&projected_ordering), + Err(e) => vec![e.to_string()], + }; + + Self { + output_ordering: format_sort_exprs(&output_ordering), + projection: projection.iter().map(|s| s.to_string()).collect(), + 
projected_ordering, + } + } + } + + fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("tag1", DataType::Utf8, true), + Field::new("tag2", DataType::Utf8, true), + Field::new("field", DataType::UInt64, true), + ])) + } + + fn format_sort_exprs(sort_exprs: &[PhysicalSortExpr]) -> Vec { + sort_exprs + .iter() + .map(|expr| { + let PhysicalSortExpr { expr, options: _ } = expr; + expr.to_string() + }) + .collect::>() + } + + fn expr_col(name: &str, schema: &SchemaRef) -> Arc { + Arc::new(Column::new_with_schema(name, schema).unwrap()) + } + + fn expr_string_cmp(col: &str, schema: &SchemaRef) -> Arc { + Arc::new(BinaryExpr::new( + expr_col(col, schema), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::from("foo"))), + )) + } + + fn expr_and(a: Arc, b: Arc) -> Arc { + Arc::new(BinaryExpr::new(a, Operator::And, b)) + } + + #[derive(Debug)] + struct TestExec { + schema: SchemaRef, + partitions: usize, + } + + impl TestExec { + fn new(schema: SchemaRef) -> Self { + Self::new_with_partitions(schema, 1) + } + + fn new_with_partitions(schema: SchemaRef, partitions: usize) -> Self { + Self { schema, partitions } + } + } + + impl ExecutionPlan for TestExec { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning { + datafusion::physical_plan::Partitioning::UnknownPartitioning(self.partitions) + } + + fn output_ordering(&self) -> Option<&[datafusion::physical_expr::PhysicalSortExpr]> { + None + } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert!(children.is_empty()); + Ok(self) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!() + } + + fn statistics(&self) -> Result { + Ok(datafusion::physical_plan::Statistics::new_unknown( + &self.schema(), + )) + } + } + + impl DisplayAs for TestExec { + 
fn fmt_as( + &self, + _t: datafusion::physical_plan::DisplayFormatType, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "Test") + } + } +} diff --git a/iox_query/src/physical_optimizer/sort/mod.rs b/iox_query/src/physical_optimizer/sort/mod.rs new file mode 100644 index 0000000..9a9be8b --- /dev/null +++ b/iox_query/src/physical_optimizer/sort/mod.rs @@ -0,0 +1,8 @@ +//! Rules specific to [`SortExec`]. +//! +//! [`SortExec`]: datafusion::physical_plan::sorts::sort::SortExec + +pub mod order_union_sorted_inputs; +pub mod parquet_sortness; +pub mod push_sort_through_union; +pub mod util; diff --git a/iox_query/src/physical_optimizer/sort/order_union_sorted_inputs.rs b/iox_query/src/physical_optimizer/sort/order_union_sorted_inputs.rs new file mode 100644 index 0000000..0266108 --- /dev/null +++ b/iox_query/src/physical_optimizer/sort/order_union_sorted_inputs.rs @@ -0,0 +1,1487 @@ +use std::sync::Arc; + +use datafusion::{ + common::tree_node::{Transformed, TreeNode}, + config::ConfigOptions, + error::Result, + physical_optimizer::PhysicalOptimizerRule, + physical_plan::{ + displayable, expressions::Column, sorts::sort_preserving_merge::SortPreservingMergeExec, + union::UnionExec, ExecutionPlan, + }, +}; +use observability_deps::tracing::{trace, warn}; + +use crate::{ + physical_optimizer::sort::util::{collect_statistics_min_max, sort_by_value_ranges}, + provider::progressive_eval::ProgressiveEvalExec, +}; + +/// IOx specific optimization that eliminates a `SortPreservingMerge` +/// by reordering inputs in terms of their value ranges. If all inputs are non overlapping and ordered +/// by value range, they can be concatenated by `ProgressiveEval` while +/// maintaining the desired output order without actually merging. 
+/// +/// Find this structure: +/// SortPreservingMergeExec - on one column (DESC or ASC) +/// UnionExec +/// and if +/// - all inputs of UnionExec are already sorted (or has SortExec) with sortExpr also on time DESC or ASC accarsdingly and +/// - the streams do not overlap in values of the sorted column +/// do: +/// - order them by the sorted column DESC or ASC accordingly and +/// - replace SortPreservingMergeExec with ProgressiveEvalExec +/// +/// Notes: The difference between SortPreservingMergeExec & ProgressiveEvalExec +/// - SortPreservingMergeExec do the merge of sorted input streams. It needs each stream sorted but the streams themselves +/// can be in any random order and they can also overlap in values of sorted columns. +/// - ProgressiveEvalExec only outputs data in their input order of the streams and not do any merges. Thus in order to +/// output data in the right sort order, these three conditions must be true: +/// 1. Each input stream must sorted on the same column DESC or ASC accordingly +/// 2. The streams must be sorted on the column DESC or ASC accordingly +/// 3. The streams must not overlap in the values of that column. +/// +/// Example: for col_name ranges: +/// |--- r1---|-- r2 ---|-- r3 ---|-- r4 --| +/// +/// Here is what the input look like: +/// +/// SortPreservingMergeExec: time@2 DESC, fetch=1 +/// UnionExec +/// SortExec: expr=col_name@2 DESC <--- input stream with col_name range r3 +/// ... +/// SortExec: expr=col_name@2 DESC <--- input stream with col_name range r1 +/// ... +/// SortExec: expr=col_name@2 DESC <--- input stream with col_name range r4 +/// ... +/// SortExec: expr=col_name@2 DESC <--- input stream with col_name range r2 -- assuming this SortExec has 2 output sorted streams +/// ... +/// +/// The streams do not overlap in time, and they are already sorted by time DESC. 
+/// +/// The output will be the same except that all the input streams will be sorted by time DESC too and looks like +/// +/// SortPreservingMergeExec: time@2 DESC, fetch=1 +/// UnionExec +/// SortExec: expr=col_name@2 DESC <--- input stream with col_name range r1 +/// ... +/// SortPreservingMergeExec: -- need this extra to merge the 2 streams into one +/// SortExec: expr=col_name@2 DESC <--- input stream with col_name range r2 +/// ... +/// SortExec: expr=col_name@2 DESC <--- input stream with col_name range r3 +/// ... +/// SortExec: expr=col_name@2 DESC <--- input stream with col_name range r4 +/// ... +/// + +pub(crate) struct OrderUnionSortedInputs; + +impl PhysicalOptimizerRule for OrderUnionSortedInputs { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + plan.transform_up(&|plan| { + // Find SortPreservingMergeExec + let Some(sort_preserving_merge_exec) = + plan.as_any().downcast_ref::() + else { + return Ok(Transformed::No(plan)); + }; + + // Check if the sortExpr is only on one column + let sort_expr = sort_preserving_merge_exec.expr(); + if sort_expr.len() != 1 { + trace!( + ?sort_expr, + "-------- sortExpr is not on one column. No optimization" + ); + return Ok(Transformed::No(plan)); + }; + let Some(sorted_col) = sort_expr[0].expr.as_any().downcast_ref::() else { + trace!( + ?sort_expr, + "-------- sortExpr is not on pure column but expression. No optimization" + ); + return Ok(Transformed::No(plan)); + }; + let sort_options = sort_expr[0].options; + + // Find UnionExec + let Some(union_exec) = sort_preserving_merge_exec + .input() + .as_any() + .downcast_ref::() + else { + trace!("-------- SortPreservingMergeExec input is not UnionExec. 
No optimization"); + return Ok(Transformed::No(plan)); + }; + + // Check all inputs of UnionExec must be already sorted and on the same sort_expr of SortPreservingMergeExec + let Some(union_output_ordering) = union_exec.output_ordering() else { + warn!(plan=%displayable(plan.as_ref()).indent(false), "Union input to SortPreservingMerge is not sorted"); + return Ok(Transformed::No(plan)); + }; + + // Check if the first PhysicalSortExpr is the same as the sortExpr[0] in SortPreservingMergeExec + if sort_expr[0] != union_output_ordering[0] { + warn!(?sort_expr, ?union_output_ordering, plan=%displayable(plan.as_ref()).indent(false), "-------- Sort order of SortPreservingMerge and its children are different"); + return Ok(Transformed::No(plan)); + } + + let Some(value_ranges) = collect_statistics_min_max(union_exec.inputs(), sorted_col.name())? + else { + return Ok(Transformed::No(plan)); + }; + + // Sort the inputs by their value ranges + trace!("-------- value_ranges: {:?}", value_ranges); + let Some(plans_value_ranges) = + sort_by_value_ranges(union_exec.inputs().to_vec(), value_ranges, sort_options)? + else { + trace!("-------- inputs are not sorted by value ranges. No optimization"); + return Ok(Transformed::No(plan)); + }; + + // If each input of UnionExec outputs many sorted streams, data of different streams may overlap and + // even if they do not overlapped, their streams can be in any order. We need to (sort) merge them first + // to have a single output stream out to guarantee the output is sorted. 
+ let new_inputs = plans_value_ranges.plans + .iter() + .map(|input| { + if input.output_partitioning().partition_count() > 1 { + // Add SortPreservingMergeExec on top of this input + let sort_preserving_merge_exec = Arc::new( + SortPreservingMergeExec::new(sort_expr.to_vec(), Arc::clone(input)) + .with_fetch(sort_preserving_merge_exec.fetch()), + ); + Ok(sort_preserving_merge_exec as _) + } else { + Ok(Arc::clone(input)) + } + }) + .collect::>>()?; + + let new_union_exec = Arc::new(UnionExec::new(new_inputs)); + + // Replace SortPreservingMergeExec with ProgressiveEvalExec + let progresive_eval_exec = Arc::new(ProgressiveEvalExec::new( + new_union_exec, + Some(plans_value_ranges.value_ranges), + sort_preserving_merge_exec.fetch(), + )); + + Ok(Transformed::Yes(progresive_eval_exec)) + }) + } + + fn name(&self) -> &str { + "order_union_sorted_inputs" + } + + fn schema_check(&self) -> bool { + true + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::{compute::SortOptions, datatypes::SchemaRef}; + use datafusion::{ + logical_expr::Operator, + physical_expr::PhysicalSortExpr, + physical_plan::{ + expressions::{BinaryExpr, Column}, + limit::GlobalLimitExec, + projection::ProjectionExec, + repartition::RepartitionExec, + sorts::{sort::SortExec, sort_preserving_merge::SortPreservingMergeExec}, + union::UnionExec, + ExecutionPlan, Partitioning, PhysicalExpr, + }, + scalar::ScalarValue, + }; + use schema::{InfluxFieldType, SchemaBuilder as IOxSchemaBuilder}; + + use crate::{ + physical_optimizer::{ + sort::order_union_sorted_inputs::OrderUnionSortedInputs, test_util::OptimizationTest, + }, + provider::{chunks_to_physical_nodes, DeduplicateExec, RecordBatchesExec}, + statistics::{column_statistics_min_max, compute_stats_column_min_max}, + test::{format_execution_plan, TestChunk}, + QueryChunk, CHUNK_ORDER_COLUMN_NAME, + }; + + // ------------------------------------------------------------------ + // Positive tests: the right structure found -> plan 
optimized + // ------------------------------------------------------------------ + + #[test] + fn test_limit_mix_record_batch_parquet_1_desc() { + test_helpers::maybe_start_logging(); + + // Input plan: + // + // GlobalLimitExec: skip=0, fetch=2 + // SortPreservingMerge: [time@2 DESC] + // UnionExec + // SortExec: expr=[time@2 DESC] -- time range [1000, 2000] + // ParquetExec -- [1000, 2000] + // SortExec: expr=[time@2 DESC] -- time range [2001, 3500] from combine time range of two record batches + // UnionExec + // RecordBatchesExec -- 3 chunks [2001, 3000] + // RecordBatchesExec -- 2 chunks [2500, 3500] + // + // Output plan: the 2 SortExecs will be swapped the order to have time range [2001, 3500] first + + let schema = schema(); + + let plan_parquet = parquet_exec_with_value_range(&schema, 1000, 2000); + let plan_batches1 = record_batches_exec_with_value_range(3, 2001, 3000); + let plan_batches2 = record_batches_exec_with_value_range(2, 2500, 3500); + + let plan_union_1 = Arc::new(UnionExec::new(vec![plan_batches1, plan_batches2])); + + let sort_order = ordering_with_options([("time", SortOp::Desc)], &schema); + let plan_sort1 = Arc::new(SortExec::new(sort_order.clone(), plan_parquet)); + let plan_sort2 = Arc::new(SortExec::new(sort_order.clone(), plan_union_1)); + + // min max of plan_sorted1 is [1000, 2000] + // structure of plan_sorted1 + let p_sort1 = Arc::clone(&plan_sort1) as Arc; + insta::assert_yaml_snapshot!( + format_execution_plan(&p_sort1), + @r###" + --- + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + let min_max_sort1 = compute_stats_column_min_max(&*plan_sort1, "time").unwrap(); + let min_max = column_statistics_min_max(&min_max_sort1).unwrap(); + assert_eq!( + min_max, + ( + ScalarValue::TimestampNanosecond(Some(1000), None), + ScalarValue::TimestampNanosecond(Some(2000), None) + ) + 
); + // + // min max of plan_sorted2 is [2001, 3500] + let p_sort2 = Arc::clone(&plan_sort2) as Arc; + insta::assert_yaml_snapshot!( + format_execution_plan(&p_sort2), + @r###" + --- + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " UnionExec" + - " RecordBatchesExec: chunks=3, projection=[col1, col2, field1, time, __chunk_order]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + "### + ); + let min_max_sort2 = compute_stats_column_min_max(&*plan_sort2, "time").unwrap(); + let min_max = column_statistics_min_max(&min_max_sort2).unwrap(); + assert_eq!( + min_max, + ( + ScalarValue::TimestampNanosecond(Some(2001), None), + ScalarValue::TimestampNanosecond(Some(3500), None) + ) + ); + + let plan_union_2 = Arc::new(UnionExec::new(vec![plan_sort1, plan_sort2])); + + let plan_spm = Arc::new(SortPreservingMergeExec::new( + sort_order.clone(), + plan_union_2, + )); + + // min max of plan_spm is [1000, 3500] + let p_spm = Arc::clone(&plan_spm) as Arc; + insta::assert_yaml_snapshot!( + format_execution_plan(&p_spm), + @r###" + --- + - " SortPreservingMergeExec: [time@3 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " UnionExec" + - " RecordBatchesExec: chunks=3, projection=[col1, col2, field1, time, __chunk_order]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + "### + ); + let min_max_spm = compute_stats_column_min_max(&*plan_spm, "time").unwrap(); + let min_max = column_statistics_min_max(&min_max_spm).unwrap(); + assert_eq!( + min_max, + ( + ScalarValue::TimestampNanosecond(Some(1000), None), + ScalarValue::TimestampNanosecond(Some(3500), None) + ) + ); + + let plan_limit = Arc::new(GlobalLimitExec::new(plan_spm, 0, Some(1))); + + // Output plan: the 2 
SortExecs will be swapped the order + let opt = OrderUnionSortedInputs; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan_limit, opt), + @r###" + --- + input: + - " GlobalLimitExec: skip=0, fetch=1" + - " SortPreservingMergeExec: [time@3 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " UnionExec" + - " RecordBatchesExec: chunks=3, projection=[col1, col2, field1, time, __chunk_order]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + output: + Ok: + - " GlobalLimitExec: skip=0, fetch=1" + - " ProgressiveEvalExec: input_ranges=[(TimestampNanosecond(2001, None), TimestampNanosecond(3500, None)), (TimestampNanosecond(1000, None), TimestampNanosecond(2000, None))]" + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " UnionExec" + - " RecordBatchesExec: chunks=3, projection=[col1, col2, field1, time, __chunk_order]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + } + + #[test] + fn test_limit_mix_record_batch_parquet_2_desc() { + test_helpers::maybe_start_logging(); + + // Input plan: + // + // GlobalLimitExec: skip=0, fetch=2 + // SortPreservingMerge: [time@2 DESC] + // UnionExec + // SortExec: expr=[time@2 DESC] -- time range [1000, 2000] + // ParquetExec -- [1000, 2000] + // SortExec: expr=[time@2 DESC] -- time range [2001, 3500] from combine time range of two record batches + // UnionExec + // SortExec: expr=[time@2 DESC] + // RecordBatchesExec -- 2 chunks [2500, 3500] + // ParquetExec -- [2001, 3000] + // + // Output 
plan: the 2 SortExecs will be swapped the order to have time range [2001, 3500] first + + let schema = schema(); + let order = ordering_with_options( + [ + ("col2", SortOp::Asc), + ("col1", SortOp::Asc), + ("time", SortOp::Asc), + ], + &schema, + ); + + let plan_parquet = parquet_exec_with_value_range(&schema, 1000, 2000); + let plan_parquet2 = parquet_exec_with_value_range(&schema, 2001, 3000); + let plan_batches = record_batches_exec_with_value_range(2, 2500, 3500); + + let plan_sort1 = Arc::new(SortExec::new(order.clone(), plan_batches)); + let plan_union_1 = Arc::new(UnionExec::new(vec![plan_sort1, plan_parquet2])); + + let sort_order = ordering_with_options([("time", SortOp::Desc)], &schema); + let plan_sort2 = Arc::new(SortExec::new(sort_order.clone(), plan_parquet)); + let plan_sort3 = Arc::new(SortExec::new(sort_order.clone(), plan_union_1)); + + let plan_union_2 = Arc::new(UnionExec::new(vec![plan_sort2, plan_sort3])); + + let plan_spm = Arc::new(SortPreservingMergeExec::new( + sort_order.clone(), + plan_union_2, + )); + + let plan_limit = Arc::new(GlobalLimitExec::new(plan_spm, 0, Some(1))); + + let opt = OrderUnionSortedInputs; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan_limit, opt), + @r###" + --- + input: + - " GlobalLimitExec: skip=0, fetch=1" + - " SortPreservingMergeExec: [time@3 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[col2@1 ASC NULLS LAST,col1@0 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + output: + Ok: + - " GlobalLimitExec: 
skip=0, fetch=1" + - " ProgressiveEvalExec: input_ranges=[(TimestampNanosecond(2001, None), TimestampNanosecond(3500, None)), (TimestampNanosecond(1000, None), TimestampNanosecond(2000, None))]" + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[col2@1 ASC NULLS LAST,col1@0 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + } + + // test on non-time column & order desc + #[test] + fn test_limit_mix_record_batch_parquet_non_time_sort_desc() { + test_helpers::maybe_start_logging(); + + // Input plan: + // + // GlobalLimitExec: skip=0, fetch=2 + // SortPreservingMerge: [field1@2 DESC] + // UnionExec + // SortExec: expr=[field1@2 DESC] -- time range [1000, 2000] + // ParquetExec -- [1000, 2000] + // SortExec: expr=[field1@2 DESC] -- time range [2001, 3500] from combine time range of two record batches + // UnionExec + // SortExec: expr=[field1@2 DESC] + // RecordBatchesExec -- 2 chunks [2500, 3500] + // ParquetExec -- [2001, 3000] + // + // Output plan: the 2 SortExecs will be swapped the order to have time range [2001, 3500] first + + let schema = schema(); + let order = ordering_with_options( + [ + ("col2", SortOp::Asc), + ("col1", SortOp::Asc), + ("field1", SortOp::Asc), + ("time", SortOp::Asc), + ], + &schema, + ); + + let plan_parquet = parquet_exec_with_value_range(&schema, 1000, 2000); + let plan_parquet2 = parquet_exec_with_value_range(&schema, 2001, 3000); + let plan_batches = record_batches_exec_with_value_range(2, 2500, 3500); + + let plan_sort1 = 
Arc::new(SortExec::new(order.clone(), plan_batches)); + let plan_union_1 = Arc::new(UnionExec::new(vec![plan_sort1, plan_parquet2])); + + let sort_order = ordering_with_options([("field1", SortOp::Desc)], &schema); + let plan_sort2 = Arc::new(SortExec::new(sort_order.clone(), plan_parquet)); + let plan_sort3 = Arc::new(SortExec::new(sort_order.clone(), plan_union_1)); + + let plan_union_2 = Arc::new(UnionExec::new(vec![plan_sort2, plan_sort3])); + + let plan_spm = Arc::new(SortPreservingMergeExec::new( + sort_order.clone(), + plan_union_2, + )); + + let plan_limit = Arc::new(GlobalLimitExec::new(plan_spm, 0, Some(1))); + + // Output plan: the 2 SortExecs will be swapped the order + let opt = OrderUnionSortedInputs; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan_limit, opt), + @r###" + --- + input: + - " GlobalLimitExec: skip=0, fetch=1" + - " SortPreservingMergeExec: [field1@2 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[field1@2 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[field1@2 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[col2@1 ASC NULLS LAST,col1@0 ASC NULLS LAST,field1@2 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + output: + Ok: + - " GlobalLimitExec: skip=0, fetch=1" + - " ProgressiveEvalExec: input_ranges=[(Int64(2001), Int64(3500)), (Int64(1000), Int64(2000))]" + - " UnionExec" + - " SortExec: expr=[field1@2 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[col2@1 ASC NULLS LAST,col1@0 ASC NULLS LAST,field1@2 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + 
- " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[field1@2 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + } + + // test on non-time column & order asc + #[test] + fn test_limit_mix_record_batch_parquet_non_time_sort_asc() { + test_helpers::maybe_start_logging(); + + // Input plan: + // + // GlobalLimitExec: skip=0, fetch=2 + // SortPreservingMerge: [field1@2 ASC] + // UnionExec + // SortExec: expr=[field1@2 ASC] -- time range [1000, 2000] + // ParquetExec -- [1000, 2000] + // SortExec: expr=[field1@2 ASC] -- time range [2001, 3500] from combine time range of two record batches + // UnionExec + // SortExec: expr=[field1@2 ASC] + // RecordBatchesExec -- 2 chunks [2500, 3500] + // ParquetExec -- [2001, 3000] + // + // Output plan: same as input plan + + let schema = schema(); + let order = ordering_with_options( + [ + ("col2", SortOp::Asc), + ("col1", SortOp::Asc), + ("field1", SortOp::Asc), + ("time", SortOp::Asc), + ], + &schema, + ); + + let plan_parquet = parquet_exec_with_value_range(&schema, 1000, 2000); + let plan_parquet2 = parquet_exec_with_value_range(&schema, 2001, 3000); + let plan_batches = record_batches_exec_with_value_range(2, 2500, 3500); + + let plan_sort1 = Arc::new(SortExec::new(order.clone(), plan_batches)); + let plan_union_1 = Arc::new(UnionExec::new(vec![plan_sort1, plan_parquet2])); + + let sort_order = ordering_with_options([("field1", SortOp::Asc)], &schema); + let plan_sort2 = Arc::new(SortExec::new(sort_order.clone(), plan_parquet)); + let plan_sort3 = Arc::new(SortExec::new(sort_order.clone(), plan_union_1)); + + let plan_union_2 = Arc::new(UnionExec::new(vec![plan_sort2, plan_sort3])); + + let plan_spm = Arc::new(SortPreservingMergeExec::new( + sort_order.clone(), + plan_union_2, + )); + 
+ let plan_limit = Arc::new(GlobalLimitExec::new(plan_spm, 0, Some(1))); + + // input and output are the same + let opt = OrderUnionSortedInputs; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan_limit, opt), + @r###" + --- + input: + - " GlobalLimitExec: skip=0, fetch=1" + - " SortPreservingMergeExec: [field1@2 ASC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[field1@2 ASC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[field1@2 ASC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[col2@1 ASC NULLS LAST,col1@0 ASC NULLS LAST,field1@2 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + output: + Ok: + - " GlobalLimitExec: skip=0, fetch=1" + - " ProgressiveEvalExec: input_ranges=[(Int64(1000), Int64(2000)), (Int64(2001), Int64(3500))]" + - " UnionExec" + - " SortExec: expr=[field1@2 ASC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[field1@2 ASC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[col2@1 ASC NULLS LAST,col1@0 ASC NULLS LAST,field1@2 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + } + + // No limit & but the input is in the right sort preserving merge struct --> optimize + #[test] + fn test_spm_time_desc() { + test_helpers::maybe_start_logging(); + + // plan: + // SortPreservingMerge: [time@2 DESC] + 
// UnionExec + // SortExec: expr=[time@2 DESC] + // ParquetExec + // SortExec: expr=[time@2 DESC] + // UnionExec + // RecordBatchesExec + // ParquetExec + // + // Output: 2 SortExec are swapped + + let schema = schema(); + + let plan_parquet = parquet_exec_with_value_range(&schema, 1000, 2000); + let plan_parquet2 = parquet_exec_with_value_range(&schema, 2001, 3000); + let plan_batches = record_batches_exec_with_value_range(2, 2500, 3500); + + let plan_union_1 = Arc::new(UnionExec::new(vec![plan_batches, plan_parquet2])); + + let sort_order = ordering_with_options([("time", SortOp::Desc)], &schema); + let plan_sort1 = Arc::new(SortExec::new(sort_order.clone(), plan_parquet)); + let plan_sort2 = Arc::new(SortExec::new(sort_order.clone(), plan_union_1)); + + let plan_union_2 = Arc::new(UnionExec::new(vec![plan_sort1, plan_sort2])); + + let plan_spm = Arc::new(SortPreservingMergeExec::new( + sort_order.clone(), + plan_union_2, + )); + + // Output plan: the 2 SortExecs will be swapped the order + let opt = OrderUnionSortedInputs; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan_spm, opt), + @r###" + --- + input: + - " SortPreservingMergeExec: [time@3 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + output: + Ok: + - " ProgressiveEvalExec: input_ranges=[(TimestampNanosecond(2001, None), TimestampNanosecond(3500, None)), (TimestampNanosecond(1000, None), TimestampNanosecond(2000, None))]" + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " UnionExec" + - " 
RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + } + + // No limit & but the input is in the right sort preserving merge struct --> optimize + #[test] + fn test_spm_non_time_desc() { + test_helpers::maybe_start_logging(); + + // plan: + // SortPreservingMerge: [field1@2 DESC] + // UnionExec + // SortExec: expr=[field1@2 DESC] + // ParquetExec + // SortExec: expr=[field1@2 DESC] + // UnionExec + // RecordBatchesExec + // ParquetExec + // + // Output: 2 SortExec are swapped + + let schema = schema(); + + let plan_parquet = parquet_exec_with_value_range(&schema, 1000, 2000); + let plan_parquet2 = parquet_exec_with_value_range(&schema, 2001, 3000); + let plan_batches = record_batches_exec_with_value_range(2, 2500, 3500); + + let plan_union_1 = Arc::new(UnionExec::new(vec![plan_batches, plan_parquet2])); + + let sort_order = ordering_with_options([("field1", SortOp::Desc)], &schema); + let plan_sort1 = Arc::new(SortExec::new(sort_order.clone(), plan_parquet)); + let plan_sort2 = Arc::new(SortExec::new(sort_order.clone(), plan_union_1)); + + let plan_union_2 = Arc::new(UnionExec::new(vec![plan_sort1, plan_sort2])); + + let plan_spm = Arc::new(SortPreservingMergeExec::new( + sort_order.clone(), + plan_union_2, + )); + + // Output plan: the 2 SortExecs will be swapped the order + let opt = OrderUnionSortedInputs; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan_spm, opt), + @r###" + --- + input: + - " SortPreservingMergeExec: [field1@2 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[field1@2 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, 
projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[field1@2 DESC NULLS LAST]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + output: + Ok: + - " ProgressiveEvalExec: input_ranges=[(Int64(2001), Int64(3500)), (Int64(1000), Int64(2000))]" + - " UnionExec" + - " SortExec: expr=[field1@2 DESC NULLS LAST]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[field1@2 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + } + + // No limit & but the input is in the right sort preserving merge struct --> optimize + #[test] + fn test_spm_non_time_asc() { + test_helpers::maybe_start_logging(); + + // plan: + // SortPreservingMerge: [field1@2 ASC] + // UnionExec + // SortExec: expr=[field1@2 ASC] + // ParquetExec + // SortExec: expr=[field1@2 ASC] + // UnionExec + // RecordBatchesExec + // ParquetExec + // + // Output: 2 SortExec ordered as above + + let schema = schema(); + + let plan_parquet = parquet_exec_with_value_range(&schema, 1000, 2000); + let plan_parquet2 = parquet_exec_with_value_range(&schema, 2001, 3000); + let plan_batches = record_batches_exec_with_value_range(2, 2500, 3500); + + let plan_union_1 = Arc::new(UnionExec::new(vec![plan_batches, plan_parquet2])); + + let sort_order = ordering_with_options([("field1", SortOp::Asc)], &schema); + let plan_sort1 = Arc::new(SortExec::new(sort_order.clone(), plan_parquet)); + let plan_sort2 = 
Arc::new(SortExec::new(sort_order.clone(), plan_union_1)); + + let plan_union_2 = Arc::new(UnionExec::new(vec![plan_sort1, plan_sort2])); + + let plan_spm = Arc::new(SortPreservingMergeExec::new( + sort_order.clone(), + plan_union_2, + )); + + // output stays the same as input + let opt = OrderUnionSortedInputs; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan_spm, opt), + @r###" + --- + input: + - " SortPreservingMergeExec: [field1@2 ASC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[field1@2 ASC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[field1@2 ASC NULLS LAST]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + output: + Ok: + - " ProgressiveEvalExec: input_ranges=[(Int64(1000), Int64(2000)), (Int64(2001), Int64(3500))]" + - " UnionExec" + - " SortExec: expr=[field1@2 ASC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[field1@2 ASC NULLS LAST]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + } + + // Plan starts with SortPreservingMerge and includes deduplication & projections. 
+ // All conditions meet --> optimize + #[test] + fn test_spm_time_desc_with_dedupe_and_proj() { + test_helpers::maybe_start_logging(); + + // plan: + // SortPreservingMerge: [time@2 DESC] + // UnionExec + // SortExec: expr=[time@2 DESC] -- time range [1000, 2000] + // ProjectionExec: expr=[time] + // ParquetExec -- [1000, 2000] + // SortExec: expr=[time@2 DESC] -- time range [2001, 3500] from combine time range of record batches & parquet + // ProjectionExec: expr=[time] + // DeduplicateExec: [col1, col2, time] + // SortPreservingMergeExec: [col1 ASC, col2 ASC, time ASC] + // UnionExec + // SortExec: expr=[col1 ASC, col2 ASC, time ASC] + // RecordBatchesExec -- 2 chunks [2500, 3500] + // SortExec: expr=[col1 ASC, col2 ASC, time ASC] + // ParquetExec -- [2001, 3000] + // + // Output: 2 SortExec are swapped + + let schema = schema(); + + let final_sort_order = ordering_with_options([("time", SortOp::Desc)], &schema); + + // Sort plan of the first parquet: + // SortExec: expr=[time@2 DESC] -- time range [1000, 2000] + // ProjectionExec: expr=[time] + // ParquetExec + let plan_parquet_1 = parquet_exec_with_value_range(&schema, 1000, 2000); + let plan_projection_1 = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("time", &schema), String::from("time"))], + plan_parquet_1, + ) + .unwrap(), + ); + let plan_sort1 = Arc::new(SortExec::new(final_sort_order.clone(), plan_projection_1)); + + // Sort plan of the second parquet and the record batch + // SortExec: expr=[time@2 DESC] -- time range [2001, 3500] from combine time range of record batches & parquet + // ProjectionExec: expr=[time] + // DeduplicateExec: [col1, col2, time] + // SortPreservingMergeExec: [col1 ASC, col2 ASC, time ASC] + // UnionExec + // SortExec: expr=[col1 ASC, col2 ASC, time ASC] + // RecordBatchesExec -- 2 chunks [2500, 3500] + // SortExec: expr=[col1 ASC, col2 ASC, time ASC] + // ParquetExec -- [2001, 3000] + let plan_parquet_2 = parquet_exec_with_value_range(&schema, 2001, 3000); + let 
plan_batches = record_batches_exec_with_value_range(2, 2500, 3500); + let dedupe_sort_order = ordering_with_options( + [ + ("col1", SortOp::Asc), + ("col2", SortOp::Asc), + ("time", SortOp::Asc), + ], + &schema, + ); + let plan_sort_rb = Arc::new(SortExec::new(dedupe_sort_order.clone(), plan_batches)); + let plan_sort_pq = Arc::new(SortExec::new(dedupe_sort_order.clone(), plan_parquet_2)); + let plan_union_1 = Arc::new(UnionExec::new(vec![plan_sort_rb, plan_sort_pq])); + let plan_spm_1 = Arc::new(SortPreservingMergeExec::new( + dedupe_sort_order.clone(), + plan_union_1, + )); + let plan_dedupe = Arc::new(DeduplicateExec::new(plan_spm_1, dedupe_sort_order, false)); + let plan_projection_2 = Arc::new( + ProjectionExec::try_new( + vec![(expr_col("time", &schema), String::from("time"))], + plan_dedupe, + ) + .unwrap(), + ); + let plan_sort2 = Arc::new(SortExec::new(final_sort_order.clone(), plan_projection_2)); + + // Union them together + let plan_union_2 = Arc::new(UnionExec::new(vec![plan_sort1, plan_sort2])); + + // SortPreservingMerge them + let plan_spm = Arc::new(SortPreservingMergeExec::new( + final_sort_order.clone(), + plan_union_2, + )); + + // compute statistics + let min_max_spm = compute_stats_column_min_max(&*plan_spm, "time").unwrap(); + let min_max = column_statistics_min_max(&min_max_spm).unwrap(); + assert_eq!( + min_max, + ( + ScalarValue::TimestampNanosecond(Some(1000), None), + ScalarValue::TimestampNanosecond(Some(3500), None) + ) + ); + + // Output plan: the 2 SortExecs will be swapped the order + let opt = OrderUnionSortedInputs; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan_spm, opt), + @r###" + --- + input: + - " SortPreservingMergeExec: [time@3 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ProjectionExec: expr=[time@3 as time]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " 
SortExec: expr=[time@3 DESC NULLS LAST]" + - " ProjectionExec: expr=[time@3 as time]" + - " DeduplicateExec: [col1@0 ASC NULLS LAST,col2@1 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " SortPreservingMergeExec: [col1@0 ASC NULLS LAST,col2@1 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[col1@0 ASC NULLS LAST,col2@1 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " SortExec: expr=[col1@0 ASC NULLS LAST,col2@1 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + output: + Ok: + - " ProgressiveEvalExec: input_ranges=[(TimestampNanosecond(2001, None), TimestampNanosecond(3500, None)), (TimestampNanosecond(1000, None), TimestampNanosecond(2000, None))]" + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ProjectionExec: expr=[time@3 as time]" + - " DeduplicateExec: [col1@0 ASC NULLS LAST,col2@1 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " SortPreservingMergeExec: [col1@0 ASC NULLS LAST,col2@1 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[col1@0 ASC NULLS LAST,col2@1 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " SortExec: expr=[col1@0 ASC NULLS LAST,col2@1 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ProjectionExec: expr=[time@3 as time]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + } + + // ------------------------------------------------------------------ + // Negative tests: 
the right structure not found -> nothing optimized + // ------------------------------------------------------------------ + + // Right stucture but sort on 2 columns --> plan stays the same + #[test] + fn test_negative_spm_2_column_sort_desc() { + test_helpers::maybe_start_logging(); + + // plan: + // SortPreservingMerge: [time@3 DESC, field1@2 DESC] + // UnionExec + // SortExec: expr=[time@3 DESC, field1@2 DESC] + // ParquetExec + // SortExec: expr=[time@3 DESC, field1@2 DESC] + // UnionExec + // RecordBatchesExec + // ParquetExec + // + // Output: same as input + + let schema = schema(); + + let plan_parquet = parquet_exec_with_value_range(&schema, 1000, 2000); + let plan_parquet2 = parquet_exec_with_value_range(&schema, 2001, 3000); + let plan_batches = record_batches_exec_with_value_range(2, 2500, 3500); + + let plan_union_1 = Arc::new(UnionExec::new(vec![plan_batches, plan_parquet2])); + + let sort_order = + ordering_with_options([("time", SortOp::Desc), ("field1", SortOp::Desc)], &schema); + let plan_sort1 = Arc::new(SortExec::new(sort_order.clone(), plan_parquet)); + let plan_sort2 = Arc::new(SortExec::new(sort_order.clone(), plan_union_1)); + + let plan_union_2 = Arc::new(UnionExec::new(vec![plan_sort1, plan_sort2])); + + let plan_spm = Arc::new(SortPreservingMergeExec::new( + sort_order.clone(), + plan_union_2, + )); + + // input and output are the same + let opt = OrderUnionSortedInputs; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan_spm, opt), + @r###" + --- + input: + - " SortPreservingMergeExec: [time@3 DESC NULLS LAST,field1@2 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST,field1@2 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[time@3 DESC NULLS LAST,field1@2 DESC NULLS LAST]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, 
time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + output: + Ok: + - " SortPreservingMergeExec: [time@3 DESC NULLS LAST,field1@2 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST,field1@2 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[time@3 DESC NULLS LAST,field1@2 DESC NULLS LAST]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + } + + // No limit & random plan --> plan stay the same + #[test] + fn test_negative_no_limit() { + test_helpers::maybe_start_logging(); + + let schema = schema(); + let order = ordering_with_options( + [ + ("col2", SortOp::Asc), + ("col1", SortOp::Asc), + ("time", SortOp::Asc), + ], + &schema, + ); + + let plan_parquet = parquet_exec_with_value_range(&schema, 1000, 2000); + let plan_batches = record_batches_exec_with_value_range(2, 1500, 2500); + + let plan = Arc::new(UnionExec::new(vec![plan_batches, plan_parquet])); + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::RoundRobinBatch(8)).unwrap()); + let hash_exprs = order.iter().cloned().map(|e| e.expr).collect(); + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::Hash(hash_exprs, 8)).unwrap()); + let plan = Arc::new(SortExec::new(order.clone(), plan)); + let plan = Arc::new(DeduplicateExec::new(plan, order, true)); + + // input and output are the same + let opt = OrderUnionSortedInputs; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [col2@1 ASC NULLS LAST,col1@0 ASC NULLS 
LAST,time@3 ASC NULLS LAST]" + - " SortExec: expr=[col2@1 ASC NULLS LAST,col1@0 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3], 8), input_partitions=8" + - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=3" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + output: + Ok: + - " DeduplicateExec: [col2@1 ASC NULLS LAST,col1@0 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " SortExec: expr=[col2@1 ASC NULLS LAST,col1@0 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3], 8), input_partitions=8" + - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=3" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + } + + // has limit but no sort preserving merge --> plan stay the same + #[test] + fn test_negative_limit_no_preserving_merge() { + test_helpers::maybe_start_logging(); + + let schema = schema(); + + let plan_batches1 = record_batches_exec_with_value_range(1, 1000, 2000); + let plan_batches2 = record_batches_exec_with_value_range(3, 2001, 3000); + let plan_batches3 = record_batches_exec_with_value_range(2, 2500, 3500); + + let plan_union_1 = Arc::new(UnionExec::new(vec![plan_batches2, plan_batches3])); + + let sort_order = ordering_with_options([("time", SortOp::Desc)], &schema); + let plan_sort1 = Arc::new(SortExec::new(sort_order.clone(), plan_batches1)); + let plan_sort2 = Arc::new(SortExec::new(sort_order.clone(), plan_union_1)); + + let plan_union_2 = Arc::new(UnionExec::new(vec![plan_sort1, plan_sort2])); 
+ + let plan_limit = Arc::new(GlobalLimitExec::new(plan_union_2, 0, Some(1))); + + // input and output are the same + let opt = OrderUnionSortedInputs; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan_limit, opt), + @r###" + --- + input: + - " GlobalLimitExec: skip=0, fetch=1" + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " RecordBatchesExec: chunks=1, projection=[col1, col2, field1, time, __chunk_order]" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " UnionExec" + - " RecordBatchesExec: chunks=3, projection=[col1, col2, field1, time, __chunk_order]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + output: + Ok: + - " GlobalLimitExec: skip=0, fetch=1" + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " RecordBatchesExec: chunks=1, projection=[col1, col2, field1, time, __chunk_order]" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " UnionExec" + - " RecordBatchesExec: chunks=3, projection=[col1, col2, field1, time, __chunk_order]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + "### + ); + } + + // right structure and same sort order but inputs of uion overlap --> plan stay the same + #[test] + fn test_negative_overlap() { + test_helpers::maybe_start_logging(); + + // Input plan: + // + // GlobalLimitExec: skip=0, fetch=2 + // SortPreservingMerge: [time@2 DESC] + // UnionExec + // SortExec: expr=[time@2 DESC] -- time range [1000, 2000] that overlaps with the other SorExec + // ParquetExec -- [1000, 2000] + // SortExec: expr=[time@2 DESC] -- time range [2000, 3500] from combine time range of two record batches + // UnionExec + // SortExec: expr=[time@2 DESC] + // RecordBatchesExec -- 2 chunks [2500, 3500] + // ParquetExec -- [2000, 3000] + + let schema = schema(); + let order = ordering_with_options( + [ + ("col2", SortOp::Asc), + ("col1", SortOp::Asc), + ("time", SortOp::Asc), + ], + &schema, + ); + + let plan_parquet = 
parquet_exec_with_value_range(&schema, 1000, 2000); + let plan_parquet2 = parquet_exec_with_value_range(&schema, 2000, 3000); + let plan_batches = record_batches_exec_with_value_range(2, 2500, 3500); + + let plan_sort1 = Arc::new(SortExec::new(order.clone(), plan_batches)); + let plan_union_1 = Arc::new(UnionExec::new(vec![plan_sort1, plan_parquet2])); + + let sort_order = ordering_with_options([("time", SortOp::Desc)], &schema); + let plan_sort2 = Arc::new(SortExec::new(sort_order.clone(), plan_parquet)); + let plan_sort3 = Arc::new(SortExec::new(sort_order.clone(), plan_union_1)); + + let plan_union_2 = Arc::new(UnionExec::new(vec![plan_sort2, plan_sort3])); + + let plan_spm = Arc::new(SortPreservingMergeExec::new( + sort_order.clone(), + plan_union_2, + )); + + let plan_limit = Arc::new(GlobalLimitExec::new(plan_spm, 0, Some(1))); + + // input and output are the same + let opt = OrderUnionSortedInputs; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan_limit, opt), + @r###" + --- + input: + - " GlobalLimitExec: skip=0, fetch=1" + - " SortPreservingMergeExec: [time@3 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[col2@1 ASC NULLS LAST,col1@0 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + output: + Ok: + - " GlobalLimitExec: skip=0, fetch=1" + - " SortPreservingMergeExec: [time@3 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, 
__chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[col2@1 ASC NULLS LAST,col1@0 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + } + + // No limit & but the input is in the right union struct --> plan stay the same + #[test] + fn test_negative_no_sortpreservingmerge_input_union() { + test_helpers::maybe_start_logging(); + + // plan: + // UnionExec + // SortExec: expr=[time@2 DESC] + // ParquetExec + // SortExec: expr=[time@2 DESC] + // UnionExec + // RecordBatchesExec + // ParquetExec + + let schema = schema(); + + let plan_parquet = parquet_exec_with_value_range(&schema, 1000, 2000); + let plan_parquet2 = parquet_exec_with_value_range(&schema, 2001, 3000); + let plan_batches = record_batches_exec_with_value_range(2, 2500, 3500); + + let plan_union_1 = Arc::new(UnionExec::new(vec![plan_batches, plan_parquet2])); + + let sort_order = ordering_with_options([("time", SortOp::Desc)], &schema); + let plan_sort1 = Arc::new(SortExec::new(sort_order.clone(), plan_parquet)); + let plan_sort2 = Arc::new(SortExec::new(sort_order.clone(), plan_union_1)); + + let plan_union_2 = Arc::new(UnionExec::new(vec![plan_sort1, plan_sort2])); + + // input and output are the same + let opt = OrderUnionSortedInputs; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan_union_2, opt), + @r###" + --- + input: + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, 
__chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + output: + Ok: + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + } + + // Projection expression (field + field) ==> not optimze. Plan stays the same + #[test] + fn test_negative_spm_time_desc_with_dedupe_and_proj_on_expr() { + test_helpers::maybe_start_logging(); + + // plan: + // SortPreservingMerge: [time@2 DESC] + // UnionExec + // SortExec: expr=[time@2 DESC] -- time range [1000, 2000] + // ProjectionExec: expr=[field1 + field1, time] <-- NOTE: has expresssion col1+col2 + // ParquetExec -- [1000, 2000] + // SortExec: expr=[time@2 DESC] -- time range [2001, 3500] from combine time range of record batches & parquet + // ProjectionExec: expr=[field1 + field1, time] <-- NOTE: has expresssion col1+col2 + // DeduplicateExec: [col1, col2, time] + // SortPreservingMergeExec: [col1 ASC, col2 ASC, time ASC] + // UnionExec + // SortExec: expr=[col1 ASC, col2 ASC, time ASC] + // RecordBatchesExec -- 2 chunks [2500, 3500] + // SortExec: expr=[col1 ASC, col2 ASC, time ASC] + // ParquetExec -- [2001, 3000] + + let schema = schema(); + + let final_sort_order = ordering_with_options([("time", SortOp::Desc)], &schema); + + // Sort plan of the first parquet: + // SortExec: expr=[time@2 DESC] -- time range [1000, 2000] + // ProjectionExec: expr=[field1 + field1, time] + // ParquetExec + let plan_parquet_1 = parquet_exec_with_value_range(&schema, 
1000, 2000); + + let field_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new_with_schema("field1", &schema).unwrap()), + Operator::Plus, + Arc::new(Column::new_with_schema("field1", &schema).unwrap()), + )); + let plan_projection_1 = Arc::new( + ProjectionExec::try_new( + vec![ + (Arc::::clone(&field_expr), String::from("field")), + (expr_col("time", &schema), String::from("time")), + ], + plan_parquet_1, + ) + .unwrap(), + ); + let plan_sort1 = Arc::new(SortExec::new(final_sort_order.clone(), plan_projection_1)); + + // Sort plan of the second parquet and the record batch + // SortExec: expr=[time@2 DESC] -- time range [2001, 3500] from combine time range of record batches & parquet + // ProjectionExec: expr=[field1 + field1, time] + // DeduplicateExec: [col1, col2, time] + // SortPreservingMergeExec: [col1 ASC, col2 ASC, time ASC] + // UnionExec + // SortExec: expr=[col1 ASC, col2 ASC, time ASC] + // RecordBatchesExec -- 2 chunks [2500, 3500] + // SortExec: expr=[col1 ASC, col2 ASC, time ASC] + // ParquetExec -- [2001, 3000] + let plan_parquet_2 = parquet_exec_with_value_range(&schema, 2001, 3000); + let plan_batches = record_batches_exec_with_value_range(2, 2500, 3500); + let dedupe_sort_order = ordering_with_options( + [ + ("col1", SortOp::Asc), + ("col2", SortOp::Asc), + ("time", SortOp::Asc), + ], + &schema, + ); + let plan_sort_rb = Arc::new(SortExec::new(dedupe_sort_order.clone(), plan_batches)); + let plan_sort_pq = Arc::new(SortExec::new(dedupe_sort_order.clone(), plan_parquet_2)); + let plan_union_1 = Arc::new(UnionExec::new(vec![plan_sort_rb, plan_sort_pq])); + let plan_spm_1 = Arc::new(SortPreservingMergeExec::new( + dedupe_sort_order.clone(), + plan_union_1, + )); + let plan_dedupe = Arc::new(DeduplicateExec::new(plan_spm_1, dedupe_sort_order, false)); + let plan_projection_2 = Arc::new( + ProjectionExec::try_new( + vec![ + (field_expr, String::from("field")), + (expr_col("time", &schema), String::from("time")), + ], + plan_dedupe, + ) + 
.unwrap(), + ); + let plan_sort2 = Arc::new(SortExec::new(final_sort_order.clone(), plan_projection_2)); + + // Union them together + let plan_union_2 = Arc::new(UnionExec::new(vec![plan_sort1, plan_sort2])); + + // SortPreservingMerge them + let plan_spm = Arc::new(SortPreservingMergeExec::new( + final_sort_order.clone(), + plan_union_2, + )); + + // compute statistics: no stats becasue the ProjectionExec includes expression + let min_max_spm = compute_stats_column_min_max(&*plan_spm, "time").unwrap(); + let min_max = column_statistics_min_max(&min_max_spm); + assert!(min_max.is_none()); + + // output plan stays the same + let opt = OrderUnionSortedInputs; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan_spm, opt), + @r###" + --- + input: + - " SortPreservingMergeExec: [time@3 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ProjectionExec: expr=[field1@2 + field1@2 as field, time@3 as time]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ProjectionExec: expr=[field1@2 + field1@2 as field, time@3 as time]" + - " DeduplicateExec: [col1@0 ASC NULLS LAST,col2@1 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " SortPreservingMergeExec: [col1@0 ASC NULLS LAST,col2@1 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[col1@0 ASC NULLS LAST,col2@1 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " SortExec: expr=[col1@0 ASC NULLS LAST,col2@1 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + output: + Ok: + - " SortPreservingMergeExec: [time@3 DESC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - 
" ProjectionExec: expr=[field1@2 + field1@2 as field, time@3 as time]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + - " SortExec: expr=[time@3 DESC NULLS LAST]" + - " ProjectionExec: expr=[field1@2 + field1@2 as field, time@3 as time]" + - " DeduplicateExec: [col1@0 ASC NULLS LAST,col2@1 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " SortPreservingMergeExec: [col1@0 ASC NULLS LAST,col2@1 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " UnionExec" + - " SortExec: expr=[col1@0 ASC NULLS LAST,col2@1 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " SortExec: expr=[col1@0 ASC NULLS LAST,col2@1 ASC NULLS LAST,time@3 ASC NULLS LAST]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + } + + // ------------------------------------------------------------------ + // Helper functions + // ------------------------------------------------------------------ + + fn schema() -> SchemaRef { + IOxSchemaBuilder::new() + .tag("col1") + .tag("col2") + .influx_field("field1", InfluxFieldType::Float) + .timestamp() + .influx_field(CHUNK_ORDER_COLUMN_NAME, InfluxFieldType::Integer) + .build() + .unwrap() + .into() + } + + fn expr_col(name: &str, schema: &SchemaRef) -> Arc { + Arc::new(Column::new_with_schema(name, schema).unwrap()) + } + + // test chunk with time range and field1's value range + fn test_chunk(min: i64, max: i64, parquet_data: bool) -> Arc { + let chunk = TestChunk::new("t") + .with_time_column_with_stats(Some(min), Some(max)) + .with_tag_column_with_stats("col1", Some("AL"), Some("MT")) + .with_tag_column_with_stats("col2", Some("MA"), Some("VY")) + .with_i64_field_column_with_stats("field1", Some(min), Some(max)); + + let chunk = if parquet_data { + 
chunk.with_dummy_parquet_file() + } else { + chunk + }; + + Arc::new(chunk) as Arc + } + + fn record_batches_exec_with_value_range( + n_chunks: usize, + min: i64, + max: i64, + ) -> Arc { + let chunks = std::iter::repeat(test_chunk(min, max, false)) + .take(n_chunks) + .collect::>(); + + Arc::new(RecordBatchesExec::new(chunks, schema(), None)) + } + + fn parquet_exec_with_value_range( + schema: &SchemaRef, + min: i64, + max: i64, + ) -> Arc { + let chunk = test_chunk(min, max, true); + let plan = chunks_to_physical_nodes(schema, None, vec![chunk], 1); + + if let Some(union_exec) = plan.as_any().downcast_ref::() { + if union_exec.inputs().len() == 1 { + Arc::clone(&union_exec.inputs()[0]) + } else { + plan + } + } else { + plan + } + } + + fn ordering_with_options( + cols: [(&str, SortOp); N], + schema: &SchemaRef, + ) -> Vec { + cols.into_iter() + .map(|col| PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema(col.0, schema.as_ref()).unwrap()), + options: SortOptions { + descending: col.1 == SortOp::Desc, + nulls_first: false, + }, + }) + .collect() + } + + #[derive(Debug, PartialEq)] + enum SortOp { + Asc, + Desc, + } +} diff --git a/iox_query/src/physical_optimizer/sort/parquet_sortness.rs b/iox_query/src/physical_optimizer/sort/parquet_sortness.rs new file mode 100644 index 0000000..c0f4a13 --- /dev/null +++ b/iox_query/src/physical_optimizer/sort/parquet_sortness.rs @@ -0,0 +1,658 @@ +use std::sync::Arc; + +use datafusion::{ + common::tree_node::{RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter}, + config::ConfigOptions, + datasource::physical_plan::{FileScanConfig, ParquetExec}, + error::Result, + physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}, + physical_optimizer::PhysicalOptimizerRule, + physical_plan::{sorts::sort::SortExec, ExecutionPlan}, +}; +use observability_deps::tracing::warn; + +use crate::config::IoxConfigExt; + +/// Trade wider fan-out of not having to sort parquet files. 
+/// +/// This will fan-out [`ParquetExec`] nodes beyond [`target_partitions`] if it is under a node that desires sorting, e.g.: +/// +/// - [`SortExec`] itself +/// - any other node that requires sorting, e.g. [`DeduplicateExec`] +/// +/// [`DeduplicateExec`]: crate::provider::DeduplicateExec +/// [`target_partitions`]: datafusion::common::config::ExecutionOptions::target_partitions +#[derive(Debug, Default)] +pub struct ParquetSortness; + +impl PhysicalOptimizerRule for ParquetSortness { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + plan.transform_down(&|plan| { + let Some(children_with_sort) = detect_children_with_desired_ordering(plan.as_ref()) + else { + return Ok(Transformed::No(plan)); + }; + let mut children_new = Vec::with_capacity(children_with_sort.len()); + for (child, desired_ordering) in children_with_sort { + let mut rewriter = ParquetSortnessRewriter { + config, + desired_ordering: &desired_ordering, + }; + let child = Arc::clone(&child).rewrite(&mut rewriter)?; + children_new.push(child); + } + + Ok(Transformed::Yes(plan.with_new_children(children_new)?)) + }) + } + + fn name(&self) -> &str { + "parquet_sortness" + } + + fn schema_check(&self) -> bool { + true + } +} + +type ChildWithSorting = (Arc, Vec); + +fn detect_children_with_desired_ordering( + plan: &dyn ExecutionPlan, +) -> Option> { + if let Some(sort_exec) = plan.as_any().downcast_ref::() { + return Some(vec![( + Arc::clone(sort_exec.input()), + sort_exec.expr().to_vec(), + )]); + } + + let required_input_ordering = plan.required_input_ordering(); + if !required_input_ordering.iter().all(|expr| expr.is_some()) { + // not all inputs require sorting, ignore it + return None; + } + + let children = plan.children(); + if children.len() != required_input_ordering.len() { + // this should normally not happen, but we ignore it + return None; + } + if children.is_empty() { + // leaf node + return None; + } + + Some( + children + .into_iter() + .zip( + 
required_input_ordering + .into_iter() + .map(|requirement| requirement.expect("just checked")) + .map(PhysicalSortRequirement::to_sort_exprs), + ) + .collect(), + ) +} + +#[derive(Debug)] +struct ParquetSortnessRewriter<'a> { + config: &'a ConfigOptions, + desired_ordering: &'a [PhysicalSortExpr], +} + +impl<'a> TreeNodeRewriter for ParquetSortnessRewriter<'a> { + type N = Arc; + + fn pre_visit(&mut self, node: &Self::N) -> Result { + if detect_children_with_desired_ordering(node.as_ref()).is_some() { + // another sort or sort-desiring node + Ok(RewriteRecursion::Stop) + } else { + Ok(RewriteRecursion::Continue) + } + } + + fn mutate(&mut self, node: Self::N) -> Result { + let Some(parquet_exec) = node.as_any().downcast_ref::() else { + // not a parquet exec + return Ok(node); + }; + + let base_config = parquet_exec.base_config(); + if base_config.output_ordering.is_empty() { + // no output ordering requested + return Ok(node); + } + + if base_config.file_groups.iter().all(|g| g.len() < 2) { + // already flat + return Ok(node); + } + + // Protect against degenerative plans + let n_files = base_config.file_groups.iter().map(Vec::len).sum::(); + let max_parquet_fanout = self + .config + .extensions + .get::() + .cloned() + .unwrap_or_default() + .max_parquet_fanout; + if n_files > max_parquet_fanout { + warn!( + n_files, + max_parquet_fanout, "cannot use pre-sorted parquet files, fan-out too wide" + ); + return Ok(node); + } + + let base_config = FileScanConfig { + file_groups: base_config + .file_groups + .iter() + .flat_map(|g| g.iter()) + .map(|f| vec![f.clone()]) + .collect(), + ..base_config.clone() + }; + let new_parquet_exec = + ParquetExec::new(base_config, parquet_exec.predicate().cloned(), None); + + // did this help? 
+ if new_parquet_exec.output_ordering() == Some(self.desired_ordering) { + Ok(Arc::new(new_parquet_exec)) + } else { + Ok(node) + } + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; + use datafusion::{ + datasource::{listing::PartitionedFile, object_store::ObjectStoreUrl}, + physical_expr::PhysicalSortExpr, + physical_plan::{ + expressions::Column, placeholder_row::PlaceholderRowExec, sorts::sort::SortExec, + union::UnionExec, Statistics, + }, + }; + use object_store::{path::Path, ObjectMeta}; + + use crate::{ + chunk_order_field, + physical_optimizer::test_util::{assert_unknown_partitioning, OptimizationTest}, + provider::{DeduplicateExec, RecordBatchesExec}, + CHUNK_ORDER_COLUMN_NAME, + }; + + use super::*; + + #[test] + fn test_happy_path_sort() { + let schema = schema(); + let base_config = FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test://").unwrap(), + file_schema: Arc::clone(&schema), + file_groups: vec![vec![file(1), file(2)]], + statistics: Statistics::new_unknown(&schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![ordering(["col2", "col1"], &schema)], + }; + let inner = ParquetExec::new(base_config, None, None); + let plan = Arc::new( + SortExec::new(ordering(["col2", "col1"], &schema), Arc::new(inner)) + .with_fetch(Some(42)), + ); + let opt = ParquetSortness; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" + - " ParquetExec: file_groups={1 group: [[1.parquet, 2.parquet]]}, projection=[col1, col2, col3], output_ordering=[col2@1 ASC, col1@0 ASC]" + output: + Ok: + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, col3], output_ordering=[col2@1 ASC, col1@0 ASC]" + "### + ); + } + + #[test] + fn test_happy_path_dedup() { + let 
schema = schema_with_chunk_order(); + let base_config = FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test://").unwrap(), + file_schema: Arc::clone(&schema), + file_groups: vec![vec![file(1), file(2)]], + statistics: Statistics::new_unknown(&schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![ordering(["col2", "col1", CHUNK_ORDER_COLUMN_NAME], &schema)], + }; + let inner = ParquetExec::new(base_config, None, None); + let plan = Arc::new(DeduplicateExec::new( + Arc::new(inner), + ordering(["col2", "col1"], &schema), + true, + )); + let opt = ParquetSortness; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC]" + - " ParquetExec: file_groups={1 group: [[1.parquet, 2.parquet]]}, projection=[col1, col2, col3, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, __chunk_order@3 ASC]" + output: + Ok: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, col3, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, __chunk_order@3 ASC]" + "### + ); + } + + #[test] + fn test_sort_partitioning() { + let schema = schema(); + let base_config = FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test://").unwrap(), + file_schema: Arc::clone(&schema), + file_groups: vec![vec![file(1), file(2)], vec![file(3)]], + statistics: Statistics::new_unknown(&schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![ordering(["col2", "col1"], &schema)], + }; + let inner = ParquetExec::new(base_config, None, None); + let plan = Arc::new( + SortExec::new(ordering(["col2", "col1"], &schema), Arc::new(inner)) + .with_preserve_partitioning(true) + .with_fetch(Some(42)), + ); + + assert_unknown_partitioning(plan.output_partitioning(), 2); + + let opt = ParquetSortness; + let test = 
OptimizationTest::new(plan, opt); + insta::assert_yaml_snapshot!( + test, + @r###" + --- + input: + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" + - " ParquetExec: file_groups={2 groups: [[1.parquet, 2.parquet], [3.parquet]]}, projection=[col1, col2, col3], output_ordering=[col2@1 ASC, col1@0 ASC]" + output: + Ok: + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" + - " ParquetExec: file_groups={3 groups: [[1.parquet], [2.parquet], [3.parquet]]}, projection=[col1, col2, col3], output_ordering=[col2@1 ASC, col1@0 ASC]" + "### + ); + + assert_unknown_partitioning(test.output_plan().unwrap().output_partitioning(), 3); + } + + #[test] + fn test_parquet_already_flat() { + let schema = schema(); + let base_config = FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test://").unwrap(), + file_schema: Arc::clone(&schema), + file_groups: vec![vec![file(1)], vec![file(2)]], + statistics: Statistics::new_unknown(&schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![ordering(["col2", "col1"], &schema)], + }; + let inner = ParquetExec::new(base_config, None, None); + let plan = Arc::new( + SortExec::new(ordering(["col2", "col1"], &schema), Arc::new(inner)) + .with_fetch(Some(42)), + ); + let opt = ParquetSortness; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, col3], output_ordering=[col2@1 ASC, col1@0 ASC]" + output: + Ok: + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, col3], output_ordering=[col2@1 ASC, col1@0 ASC]" + "### + ); + } + + #[test] + fn test_parquet_has_different_ordering() { + let schema = schema(); + let base_config = FileScanConfig { + object_store_url: 
ObjectStoreUrl::parse("test://").unwrap(), + file_schema: Arc::clone(&schema), + file_groups: vec![vec![file(1), file(2)]], + statistics: Statistics::new_unknown(&schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![ordering(["col1", "col2"], &schema)], + }; + let inner = ParquetExec::new(base_config, None, None); + let plan = Arc::new( + SortExec::new(ordering(["col2", "col1"], &schema), Arc::new(inner)) + .with_fetch(Some(42)), + ); + let opt = ParquetSortness; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" + - " ParquetExec: file_groups={1 group: [[1.parquet, 2.parquet]]}, projection=[col1, col2, col3], output_ordering=[col1@0 ASC, col2@1 ASC]" + output: + Ok: + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" + - " ParquetExec: file_groups={1 group: [[1.parquet, 2.parquet]]}, projection=[col1, col2, col3], output_ordering=[col1@0 ASC, col2@1 ASC]" + "### + ); + } + + #[test] + fn test_parquet_has_no_ordering() { + let schema = schema(); + let base_config = FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test://").unwrap(), + file_schema: Arc::clone(&schema), + file_groups: vec![vec![file(1), file(2)]], + statistics: Statistics::new_unknown(&schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + }; + let inner = ParquetExec::new(base_config, None, None); + let plan = Arc::new( + SortExec::new(ordering(["col2", "col1"], &schema), Arc::new(inner)) + .with_fetch(Some(42)), + ); + let opt = ParquetSortness; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" + - " ParquetExec: file_groups={1 group: [[1.parquet, 2.parquet]]}, projection=[col1, col2, col3]" + output: + Ok: + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" + - " 
ParquetExec: file_groups={1 group: [[1.parquet, 2.parquet]]}, projection=[col1, col2, col3]" + "### + ); + } + + #[test] + fn test_fanout_limit() { + let schema = schema(); + let base_config = FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test://").unwrap(), + file_schema: Arc::clone(&schema), + file_groups: vec![vec![file(1), file(2), file(3)]], + statistics: Statistics::new_unknown(&schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![ordering(["col2", "col1"], &schema)], + }; + let inner = ParquetExec::new(base_config, None, None); + let plan = Arc::new( + SortExec::new(ordering(["col2", "col1"], &schema), Arc::new(inner)) + .with_fetch(Some(42)), + ); + let opt = ParquetSortness; + let mut config = ConfigOptions::default(); + config.extensions.insert(IoxConfigExt { + max_parquet_fanout: 2, + ..Default::default() + }); + insta::assert_yaml_snapshot!( + OptimizationTest::new_with_config(plan, opt, &config), + @r###" + --- + input: + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" + - " ParquetExec: file_groups={1 group: [[1.parquet, 2.parquet, 3.parquet]]}, projection=[col1, col2, col3], output_ordering=[col2@1 ASC, col1@0 ASC]" + output: + Ok: + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" + - " ParquetExec: file_groups={1 group: [[1.parquet, 2.parquet, 3.parquet]]}, projection=[col1, col2, col3], output_ordering=[col2@1 ASC, col1@0 ASC]" + "### + ); + } + + #[test] + fn test_other_node() { + let schema = schema(); + let inner = PlaceholderRowExec::new(Arc::clone(&schema)); + let plan = Arc::new( + SortExec::new(ordering(["col2", "col1"], &schema), Arc::new(inner)) + .with_fetch(Some(42)), + ); + let opt = ParquetSortness; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" + - " PlaceholderRowExec" + output: + Ok: + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" 
+ - " PlaceholderRowExec" + "### + ); + } + + #[test] + fn test_does_not_touch_freestanding_parquet_exec() { + let schema = schema(); + let base_config = FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test://").unwrap(), + file_schema: Arc::clone(&schema), + file_groups: vec![vec![file(1), file(2)]], + statistics: Statistics::new_unknown(&schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![ordering(["col2", "col1"], &schema)], + }; + let plan = Arc::new(ParquetExec::new(base_config, None, None)); + let opt = ParquetSortness; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " ParquetExec: file_groups={1 group: [[1.parquet, 2.parquet]]}, projection=[col1, col2, col3], output_ordering=[col2@1 ASC, col1@0 ASC]" + output: + Ok: + - " ParquetExec: file_groups={1 group: [[1.parquet, 2.parquet]]}, projection=[col1, col2, col3], output_ordering=[col2@1 ASC, col1@0 ASC]" + "### + ); + } + + #[test] + fn test_ignore_outer_sort_if_inner_preform_resort() { + let schema = schema(); + let base_config = FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test://").unwrap(), + file_schema: Arc::clone(&schema), + file_groups: vec![vec![file(1), file(2)]], + statistics: Statistics::new_unknown(&schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![ordering(["col1", "col2"], &schema)], + }; + let plan = Arc::new(ParquetExec::new(base_config, None, None)); + let plan = + Arc::new(SortExec::new(ordering(["col2", "col1"], &schema), plan).with_fetch(Some(42))); + let plan = + Arc::new(SortExec::new(ordering(["col1", "col2"], &schema), plan).with_fetch(Some(42))); + let opt = ParquetSortness; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " SortExec: TopK(fetch=42), expr=[col1@0 ASC,col2@1 ASC]" + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" + - " ParquetExec: 
file_groups={1 group: [[1.parquet, 2.parquet]]}, projection=[col1, col2, col3], output_ordering=[col1@0 ASC, col2@1 ASC]" + output: + Ok: + - " SortExec: TopK(fetch=42), expr=[col1@0 ASC,col2@1 ASC]" + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" + - " ParquetExec: file_groups={1 group: [[1.parquet, 2.parquet]]}, projection=[col1, col2, col3], output_ordering=[col1@0 ASC, col2@1 ASC]" + "### + ); + } + + #[test] + fn test_honor_inner_sort_even_if_outer_preform_resort() { + let schema = schema(); + let base_config = FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test://").unwrap(), + file_schema: Arc::clone(&schema), + file_groups: vec![vec![file(1), file(2)]], + statistics: Statistics::new_unknown(&schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![ordering(["col1", "col2"], &schema)], + }; + let plan = Arc::new(ParquetExec::new(base_config, None, None)); + let plan = + Arc::new(SortExec::new(ordering(["col1", "col2"], &schema), plan).with_fetch(Some(42))); + let plan = + Arc::new(SortExec::new(ordering(["col2", "col1"], &schema), plan).with_fetch(Some(42))); + let opt = ParquetSortness; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" + - " SortExec: TopK(fetch=42), expr=[col1@0 ASC,col2@1 ASC]" + - " ParquetExec: file_groups={1 group: [[1.parquet, 2.parquet]]}, projection=[col1, col2, col3], output_ordering=[col1@0 ASC, col2@1 ASC]" + output: + Ok: + - " SortExec: TopK(fetch=42), expr=[col2@1 ASC,col1@0 ASC]" + - " SortExec: TopK(fetch=42), expr=[col1@0 ASC,col2@1 ASC]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, col3], output_ordering=[col1@0 ASC, col2@1 ASC]" + "### + ); + } + + #[test] + fn test_issue_idpe_17556() { + let schema = schema_with_chunk_order(); + + let base_config = FileScanConfig { + object_store_url: 
ObjectStoreUrl::parse("test://").unwrap(), + file_schema: Arc::clone(&schema), + file_groups: vec![vec![file(1), file(2)]], + statistics: Statistics::new_unknown(&schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![ordering(["col2", "col1", CHUNK_ORDER_COLUMN_NAME], &schema)], + }; + let plan_parquet = Arc::new(ParquetExec::new(base_config, None, None)); + let plan_batches = Arc::new(RecordBatchesExec::new(vec![], Arc::clone(&schema), None)); + + let plan = Arc::new(UnionExec::new(vec![plan_batches, plan_parquet])); + let plan = Arc::new(DeduplicateExec::new( + plan, + ordering(["col2", "col1"], &schema), + true, + )); + let opt = ParquetSortness; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=0, projection=[col1, col2, col3, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[1.parquet, 2.parquet]]}, projection=[col1, col2, col3, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, __chunk_order@3 ASC]" + output: + Ok: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=0, projection=[col1, col2, col3, __chunk_order]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, col3, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, __chunk_order@3 ASC]" + "### + ); + } + + fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("col1", DataType::Utf8, true), + Field::new("col2", DataType::Utf8, true), + Field::new("col3", DataType::Utf8, true), + ])) + } + + fn schema_with_chunk_order() -> SchemaRef { + Arc::new(Schema::new( + schema() + .fields() + .iter() + .cloned() + .chain(std::iter::once(chunk_order_field())) + .collect::(), + )) + } + + fn file(n: u128) -> PartitionedFile { + PartitionedFile { + object_meta: ObjectMeta { + location: 
Path::parse(format!("{n}.parquet")).unwrap(), + last_modified: Default::default(), + size: 0, + e_tag: None, + version: None, + }, + partition_values: vec![], + range: None, + extensions: None, + } + } + + fn ordering(cols: [&str; N], schema: &SchemaRef) -> Vec { + cols.into_iter() + .map(|col| PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema(col, schema.as_ref()).unwrap()), + options: Default::default(), + }) + .collect() + } +} diff --git a/iox_query/src/physical_optimizer/sort/push_sort_through_union.rs b/iox_query/src/physical_optimizer/sort/push_sort_through_union.rs new file mode 100644 index 0000000..f76772a --- /dev/null +++ b/iox_query/src/physical_optimizer/sort/push_sort_through_union.rs @@ -0,0 +1,706 @@ +use std::sync::Arc; + +use datafusion::{ + common::{ + internal_err, + tree_node::{RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter}, + }, + config::ConfigOptions, + error::{DataFusionError, Result}, + physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}, + physical_optimizer::PhysicalOptimizerRule, + physical_plan::{ + repartition::RepartitionExec, sorts::sort::SortExec, union::UnionExec, ExecutionPlan, + }, +}; + +/// Pushes a [`SortExec`] through a [`UnionExec`], possibly +/// including multiple [`RepartitionExec`] nodes (converting them +/// to be sort-preserving in the process), provided that at least +/// one of the children of the union is already sorted. 
+/// +/// In other words, a typical plan like this +/// ```text +/// DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC] +/// SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC] +/// RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8 +/// RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4 +/// UnionExec +/// RecordBatchesExec: batches_groups=2 batches=0 total_rows=0 +/// ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC] +/// ``` +/// will become: +/// ```text +/// DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC] +/// SortPreservingRepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8 +/// SortPreservingRepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4 +/// UnionExec +/// SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC] +/// RecordBatchesExec: batches_groups=2 batches=0 total_rows=0 +/// ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC] +/// ``` +/// +/// There is a tension between: +/// - Wanting to do sorts in parallel +/// - Sorting fewer rows +/// +/// DataFusion will not push down a sort through a `RepartitionExec` +/// because it could reduce the parallelism of the sort. However, +/// in IOx, unsorted children of `UnionExec` will tend to be +/// [`RecordBatchesExec`] which is likely to have many fewer rows than +/// other children which will tend to be [`ParquetExec`]. +/// So making this transformation will generally have a dramatic effect +/// on the amount of data being sorted. 
+/// +/// [`RecordBatchesExec`]: crate::provider::RecordBatchesExec +/// [`ParquetExec`]: datafusion::datasource::physical_plan::ParquetExec +pub(crate) struct PushSortThroughUnion; + +impl PhysicalOptimizerRule for PushSortThroughUnion { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + plan.transform_up(&|plan| { + let Some(sort_exec) = plan.as_any().downcast_ref::() else { + return Ok(Transformed::No(plan)); + }; + + if !sort_should_be_pushed_down(sort_exec)? { + return Ok(Transformed::No(plan)); + } + + let mut plan = Arc::clone(sort_exec.input()); + let mut rewriter = SortRewriter { + ordering: sort_exec.output_ordering().unwrap().to_vec(), + }; + + plan = plan.rewrite(&mut rewriter)?; + + // As a sanity check, make sure plan has the same ordering as before. + // If this fails, there is a bug in this optimization. + let Some(required_order) = sort_exec.output_ordering().map(sort_exprs_to_requirement) + else { + return internal_err!("No sort order after a sort"); + }; + + if !plan + .equivalence_properties() + .ordering_satisfy_requirement(&required_order) + { + return internal_err!("PushSortThroughUnion corrupted plan sort order"); + } + + Ok(Transformed::Yes(plan)) + }) + } + + fn name(&self) -> &str { + "push_sort_through_union" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// Returns true if the [`SortExec`] can be pushed down beneath a [`UnionExec`]. 
+fn sort_should_be_pushed_down(sort_exec: &SortExec) -> Result { + // Skip over any RepartitionExecs + let mut input = sort_exec.input(); + while input.as_any().is::() { + input = input + .as_any() + .downcast_ref::() + .expect("this must be a RepartitionExec") + .input(); + } + + let Some(union_exec) = input.as_any().downcast_ref::() else { + return Ok(false); + }; + + let Some(required_order) = sort_exec.output_ordering().map(sort_exprs_to_requirement) else { + return internal_err!("No sort order after a sort"); + }; + + // Push down the sort if any of the children are already sorted. + // This means we will need to sort fewer rows than if we didn't + // push down the sort. + Ok(union_exec.children().iter().any(|child| { + child + .equivalence_properties() + .ordering_satisfy_requirement(&required_order) + })) +} + +/// Rewrites a plan: +/// - Any [`RepartitionExec`] nodes are converted to be sort-preserving +/// - Any children of a [`UnionExec`] that are not sorted get a [`SortExec`] +/// added to them. +/// - Any other nodes will stop the rewrite. +struct SortRewriter { + ordering: Vec, +} + +impl TreeNodeRewriter for SortRewriter { + type N = Arc; + + fn pre_visit(&mut self, plan: &Self::N) -> Result { + if plan.as_any().is::() { + Ok(datafusion::common::tree_node::RewriteRecursion::Continue) + } else if plan.as_any().is::() { + Ok(datafusion::common::tree_node::RewriteRecursion::Mutate) + } else { + Ok(datafusion::common::tree_node::RewriteRecursion::Stop) + } + } + + fn mutate(&mut self, plan: Self::N) -> Result { + if let Some(repartition_exec) = plan.as_any().downcast_ref::() { + // Convert any RepartitionExec to be sort-preserving + Ok(Arc::new( + RepartitionExec::try_new( + Arc::clone(repartition_exec.input()), + repartition_exec.output_partitioning(), + )? + .with_preserve_order(), + )) + } else if let Some(union_exec) = plan.as_any().downcast_ref::() { + // Any children of the UnionExec that are not already sorted, + // need to be sorted. 
+ let required_ordering = sort_exprs_to_requirement(self.ordering.as_ref()); + + let new_children = union_exec + .children() + .into_iter() + .map(|child| { + if !child + .equivalence_properties() + .ordering_satisfy_requirement(&required_ordering) + { + let sort_exec = SortExec::new(self.ordering.clone(), child) + .with_preserve_partitioning(true); + Arc::new(sort_exec) + } else { + child + } + }) + .collect(); + + Ok(Arc::new(UnionExec::new(new_children))) + } else { + Ok(plan) + } + } +} + +fn sort_exprs_to_requirement(sort_exprs: &[PhysicalSortExpr]) -> Vec { + sort_exprs + .iter() + .map(|sort_expr| sort_expr.clone().into()) + .collect() +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::datatypes::SchemaRef; + use datafusion::{ + datasource::{ + listing::PartitionedFile, + object_store::ObjectStoreUrl, + physical_plan::{FileScanConfig, ParquetExec}, + }, + physical_expr::PhysicalSortExpr, + physical_plan::{ + coalesce_batches::CoalesceBatchesExec, expressions::Column, + repartition::RepartitionExec, sorts::sort::SortExec, union::UnionExec, ExecutionPlan, + Partitioning, Statistics, + }, + }; + use object_store::{path::Path, ObjectMeta}; + use schema::{InfluxFieldType, SchemaBuilder as IOxSchemaBuilder}; + + use crate::{ + physical_optimizer::{ + sort::push_sort_through_union::PushSortThroughUnion, test_util::OptimizationTest, + }, + provider::{DeduplicateExec, RecordBatchesExec}, + test::TestChunk, + CHUNK_ORDER_COLUMN_NAME, + }; + + #[test] + fn test_push_sort_through_union() { + test_helpers::maybe_start_logging(); + + let schema = schema(); + let order = ordering(["col2", "col1", "time", CHUNK_ORDER_COLUMN_NAME], &schema); + + let plan_parquet = parquet_exec(&schema, &order); + let plan_batches = record_batches_exec(2); + + let plan = Arc::new(UnionExec::new(vec![plan_batches, plan_parquet])); + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::RoundRobinBatch(8)).unwrap()); + let hash_exprs = order.iter().cloned().map(|e| 
e.expr).collect(); + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::Hash(hash_exprs, 8)).unwrap()); + let plan = Arc::new(SortExec::new(order.clone(), plan)); + let plan = Arc::new(DeduplicateExec::new(plan, order, true)); + + let opt = PushSortThroughUnion; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8" + - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC]" + output: + Ok: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8, preserve_order=true, sort_exprs=col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC" + - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4, preserve_order=true, sort_exprs=col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC" + - " UnionExec" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC]" + "### + ); + } + + #[test] + fn test_push_sort_through_union_top_level_sort() { + test_helpers::maybe_start_logging(); + + let schema = schema(); + 
let order = ordering(["col2", "col1", "time", CHUNK_ORDER_COLUMN_NAME], &schema); + + let plan_parquet = parquet_exec(&schema, &order); + let plan_batches = record_batches_exec(2); + + let plan = Arc::new(UnionExec::new(vec![plan_batches, plan_parquet])); + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::RoundRobinBatch(8)).unwrap()); + let hash_exprs = order.iter().cloned().map(|e| e.expr).collect(); + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::Hash(hash_exprs, 8)).unwrap()); + let plan = Arc::new(SortExec::new(order.clone(), plan)); + let plan = Arc::new(DeduplicateExec::new(plan, order, true)); + + let output_order = ordering(["time"], &schema); + let plan = Arc::new(SortExec::new(output_order, plan)); + + // Nothing is done with the SortExec at the top level, because + // it does not match the pattern. + let opt = PushSortThroughUnion; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " SortExec: expr=[time@3 ASC]" + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8" + - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC]" + output: + Ok: + - " SortExec: expr=[time@3 ASC]" + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8, preserve_order=true, sort_exprs=col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC" + - " 
RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4, preserve_order=true, sort_exprs=col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC" + - " UnionExec" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC]" + "### + ); + } + + #[test] + fn test_push_sort_through_union_no_repartition() { + test_helpers::maybe_start_logging(); + + let schema = schema(); + let order = ordering(["col2", "col1", "time", CHUNK_ORDER_COLUMN_NAME], &schema); + + let plan_parquet = parquet_exec(&schema, &order); + let plan_batches = record_batches_exec(2); + + let plan = Arc::new(UnionExec::new(vec![plan_batches, plan_parquet])); + let plan = Arc::new(SortExec::new(order.clone(), plan)); + let plan = Arc::new(DeduplicateExec::new(plan, order, true)); + + // RepartitionExec does not need to be present for the optimization to apply + // (Although DF *will* handle this case) + let opt = PushSortThroughUnion; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC]" + output: + Ok: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " UnionExec" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RecordBatchesExec: 
chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC]" + "### + ); + } + + #[test] + fn test_no_sorted_children() { + test_helpers::maybe_start_logging(); + + let schema = schema(); + let order = ordering(["col2", "col1", "time", CHUNK_ORDER_COLUMN_NAME], &schema); + + let plan_batches_1 = record_batches_exec(2); + let plan_batches_2 = record_batches_exec(2); + + let plan = Arc::new(UnionExec::new(vec![plan_batches_1, plan_batches_2])); + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::RoundRobinBatch(8)).unwrap()); + let hash_exprs = order.iter().cloned().map(|e| e.expr).collect(); + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::Hash(hash_exprs, 8)).unwrap()); + let plan = Arc::new(SortExec::new(order.clone(), plan)); + let plan = Arc::new(DeduplicateExec::new(plan, order, true)); + + // No children of the union are sorted, so the sort will not be pushed down. 
+ let opt = PushSortThroughUnion; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8" + - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + output: + Ok: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8" + - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + "### + ); + } + + #[test] + fn test_all_sorted_children() { + test_helpers::maybe_start_logging(); + + let schema = schema(); + let order = ordering(["col2", "col1", "time", CHUNK_ORDER_COLUMN_NAME], &schema); + + let plan_parquet_1 = parquet_exec(&schema, &order); + let plan_parquet_2 = parquet_exec(&schema, &order); + + let plan = Arc::new(UnionExec::new(vec![plan_parquet_1, plan_parquet_2])); + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::RoundRobinBatch(8)).unwrap()); + let hash_exprs = order.iter().cloned().map(|e| e.expr).collect(); + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::Hash(hash_exprs, 8)).unwrap()); + let plan = Arc::new(SortExec::new(order.clone(), plan)); + let plan = Arc::new(DeduplicateExec::new(plan, order, 
true)); + + // All children of the union are sorted, so RepartitionExec nodes are converted to + // be sort-preserving. + let opt = PushSortThroughUnion; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8" + - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4" + - " UnionExec" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC]" + output: + Ok: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8, preserve_order=true, sort_exprs=col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC" + - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4, preserve_order=true, sort_exprs=col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC" + - " UnionExec" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC]" + "### + ); + } + + #[test] + fn test_no_union() { + test_helpers::maybe_start_logging(); + + let schema = schema(); + let order 
= ordering(["col2", "col1", "time", CHUNK_ORDER_COLUMN_NAME], &schema); + + let plan = parquet_exec(&schema, &order); + + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::RoundRobinBatch(8)).unwrap()); + let hash_exprs = order.iter().cloned().map(|e| e.expr).collect(); + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::Hash(hash_exprs, 8)).unwrap()); + let plan = Arc::new(SortExec::new(order.clone(), plan)); + let plan = Arc::new(DeduplicateExec::new(plan, order, true)); + + // There is no union in the plan, so the pattern does not match. + let opt = PushSortThroughUnion; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8" + - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=2" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC]" + output: + Ok: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8" + - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=2" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC]" + "### + ); + } + + #[test] + fn test_two_sorts() { + test_helpers::maybe_start_logging(); + + let schema = schema(); + let order = ordering(["col2", "col1", "time", CHUNK_ORDER_COLUMN_NAME], 
&schema); + + let plan_parquet = parquet_exec(&schema, &order); + let plan_batches = record_batches_exec(2); + + let plan = Arc::new(UnionExec::new(vec![plan_batches, plan_parquet])); + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::RoundRobinBatch(8)).unwrap()); + let hash_exprs = order.iter().cloned().map(|e| e.expr).collect(); + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::Hash(hash_exprs, 8)).unwrap()); + let plan = Arc::new(SortExec::new(order.clone(), plan)); + let plan = Arc::new(SortExec::new(order.clone(), plan)); + let plan = Arc::new(DeduplicateExec::new(plan, order, true)); + + // With two identical sorts in the plan, both of them will be removed, + // because the transformation is applied bottom-up. + let opt = PushSortThroughUnion; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8" + - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC]" + output: + Ok: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8, preserve_order=true, sort_exprs=col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC" + - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4, preserve_order=true, 
sort_exprs=col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC" + - " UnionExec" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC]" + "### + ); + } + + #[test] + fn test_extra_node() { + test_helpers::maybe_start_logging(); + + let schema = schema(); + let order = ordering(["col2", "col1", "time", CHUNK_ORDER_COLUMN_NAME], &schema); + + let plan_parquet = parquet_exec(&schema, &order); + let plan_batches = record_batches_exec(2); + + let plan = Arc::new(UnionExec::new(vec![plan_batches, plan_parquet])); + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::RoundRobinBatch(8)).unwrap()); + let hash_exprs = order.iter().cloned().map(|e| e.expr).collect(); + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::Hash(hash_exprs, 8)).unwrap()); + let plan = Arc::new(CoalesceBatchesExec::new(plan, 4096)); + let plan = Arc::new(SortExec::new(order.clone(), plan)); + let plan = Arc::new(DeduplicateExec::new(plan, order, true)); + + // Extra nodes in the plan, like CoalesceBatchesExec, will break the pattern matching + // and prevent the transformation from occurring. 
+ let opt = PushSortThroughUnion; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " CoalesceBatchesExec: target_batch_size=4096" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8" + - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC]" + output: + Ok: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " CoalesceBatchesExec: target_batch_size=4096" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8" + - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col2@1 ASC, col1@0 ASC, time@3 ASC, __chunk_order@4 ASC]" + "### + ); + } + + #[test] + fn test_wrong_order() { + test_helpers::maybe_start_logging(); + + let schema = schema(); + let order = ordering(["col2", "col1", "time", CHUNK_ORDER_COLUMN_NAME], &schema); + + let wrong_order = ordering(["col1", "col2", "time", CHUNK_ORDER_COLUMN_NAME], &schema); + let plan_parquet = parquet_exec(&schema, &wrong_order); + let plan_batches = record_batches_exec(2); + + let plan = Arc::new(UnionExec::new(vec![plan_batches, 
plan_parquet])); + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::RoundRobinBatch(8)).unwrap()); + let hash_exprs = order.iter().cloned().map(|e| e.expr).collect(); + let plan = + Arc::new(RepartitionExec::try_new(plan, Partitioning::Hash(hash_exprs, 8)).unwrap()); + let plan = Arc::new(SortExec::new(order.clone(), plan)); + let plan = Arc::new(DeduplicateExec::new(plan, order, true)); + + // The ParquetExec has the wrong output order so no children of the union have the right + // sort order. Therefore the optimization is not applied. + let opt = PushSortThroughUnion; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8" + - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col1@0 ASC, col2@1 ASC, time@3 ASC, __chunk_order@4 ASC]" + output: + Ok: + - " DeduplicateExec: [col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " SortExec: expr=[col2@1 ASC,col1@0 ASC,time@3 ASC,__chunk_order@4 ASC]" + - " RepartitionExec: partitioning=Hash([col2@1, col1@0, time@3, __chunk_order@4], 8), input_partitions=8" + - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4" + - " UnionExec" + - " RecordBatchesExec: chunks=2, projection=[col1, col2, field1, time, __chunk_order]" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[col1, col2, field1, time, __chunk_order], output_ordering=[col1@0 ASC, col2@1 ASC, time@3 ASC, __chunk_order@4 ASC]" + 
"### + ); + } + + fn record_batches_exec(n_chunks: usize) -> Arc { + let chunks = std::iter::repeat(Arc::new(TestChunk::new("t")) as _) + .take(n_chunks) + .collect::>(); + Arc::new(RecordBatchesExec::new(chunks, schema(), None)) + } + + fn parquet_exec(schema: &SchemaRef, order: &[PhysicalSortExpr]) -> Arc { + let base_config = FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test://").unwrap(), + file_schema: Arc::clone(schema), + file_groups: vec![vec![file(1)], vec![file(2)]], + statistics: Statistics::new_unknown(schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![order.to_vec()], + }; + Arc::new(ParquetExec::new(base_config, None, None)) + } + + fn schema() -> SchemaRef { + IOxSchemaBuilder::new() + .tag("col1") + .tag("col2") + .influx_field("field1", InfluxFieldType::Float) + .timestamp() + .influx_field(CHUNK_ORDER_COLUMN_NAME, InfluxFieldType::Integer) + .build() + .unwrap() + .into() + } + + fn file(n: u128) -> PartitionedFile { + PartitionedFile { + object_meta: ObjectMeta { + location: Path::parse(format!("{n}.parquet")).unwrap(), + last_modified: Default::default(), + size: 0, + e_tag: None, + version: None, + }, + partition_values: vec![], + range: None, + extensions: None, + } + } + + fn ordering(cols: [&str; N], schema: &SchemaRef) -> Vec { + cols.into_iter() + .map(|col| PhysicalSortExpr { + expr: Arc::new(Column::new_with_schema(col, schema.as_ref()).unwrap()), + options: Default::default(), + }) + .collect() + } +} diff --git a/iox_query/src/physical_optimizer/sort/util.rs b/iox_query/src/physical_optimizer/sort/util.rs new file mode 100644 index 0000000..274b016 --- /dev/null +++ b/iox_query/src/physical_optimizer/sort/util.rs @@ -0,0 +1,102 @@ +use std::sync::Arc; + +use crate::statistics::{column_statistics_min_max, compute_stats_column_min_max, overlap}; +use arrow::compute::{rank, SortOptions}; +use datafusion::{error::Result, physical_plan::ExecutionPlan, scalar::ScalarValue}; +use 
observability_deps::tracing::trace; + +/// Compute statistics for the given plans on a given column name +/// Return none if the statistics are not available +pub(crate) fn collect_statistics_min_max( + plans: &[Arc], + col_name: &str, +) -> Result>> { + // temp solution while waiting for DF's statistics to get mature + // Compute min max stats for all inputs of UnionExec on the sorted column + // https://github.com/apache/arrow-datafusion/issues/8078 + let col_stats = plans + .iter() + .map(|plan| compute_stats_column_min_max(&**plan, col_name)) + .collect::>>()?; + + // If min and max not available, return none + let mut value_ranges = Vec::with_capacity(col_stats.len()); + for stats in &col_stats { + let Some((min, max)) = column_statistics_min_max(stats) else { + trace!("-------- min_max not available"); + return Ok(None); + }; + + value_ranges.push((min, max)); + } + + // todo: use this when DF satistics is ready + // // Get statistics for the inputs of UnionExec on the sorted column + // let Some(value_ranges) = statistics_min_max(plans, col_name) + // else { + // return Ok(None); + // }; + + Ok(Some(value_ranges)) +} + +/// Plans and their corresponding value ranges +pub(crate) struct PlansValueRanges { + pub plans: Vec>, + // Min and max values of the plan on a specific column + pub value_ranges: Vec<(ScalarValue, ScalarValue)>, +} + +/// Sort the given plans by value ranges +/// Return none if +/// . the number of plans is not the same as the number of value ranges +/// . the value ranges overlap +pub(crate) fn sort_by_value_ranges( + plans: Vec>, + value_ranges: Vec<(ScalarValue, ScalarValue)>, + sort_options: SortOptions, +) -> Result> { + if plans.len() != value_ranges.len() { + trace!( + plans.len = plans.len(), + value_ranges.len = value_ranges.len(), + "--------- number of plans is not the same as the number of value ranges" + ); + return Ok(None); + } + + if overlap(&value_ranges)? 
{ + trace!("--------- value ranges overlap"); + return Ok(None); + } + + // get the min value of each value range + let min_iter = value_ranges.iter().map(|(min, _)| min.clone()); + let mins = ScalarValue::iter_to_array(min_iter)?; + + // rank the min values + let ranks = rank(&*mins, Some(sort_options))?; + + // sort the plans by the ranks of their min values + let mut plan_rank_zip: Vec<(Arc, u32)> = + plans.into_iter().zip(ranks.clone()).collect::>(); + plan_rank_zip.sort_by(|(_, min1), (_, min2)| min1.cmp(min2)); + let plans = plan_rank_zip + .into_iter() + .map(|(plan, _)| plan) + .collect::>(); + + // Sort the value ranges by the ranks of their min values + let mut value_range_rank_zip: Vec<((ScalarValue, ScalarValue), u32)> = + value_ranges.into_iter().zip(ranks).collect::>(); + value_range_rank_zip.sort_by(|(_, min1), (_, min2)| min1.cmp(min2)); + let value_ranges = value_range_rank_zip + .into_iter() + .map(|(value_range, _)| value_range) + .collect::>(); + + Ok(Some(PlansValueRanges { + plans, + value_ranges, + })) +} diff --git a/iox_query/src/physical_optimizer/test_util.rs b/iox_query/src/physical_optimizer/test_util.rs new file mode 100644 index 0000000..d02c21a --- /dev/null +++ b/iox_query/src/physical_optimizer/test_util.rs @@ -0,0 +1,87 @@ +use std::sync::Arc; + +use datafusion::{ + config::ConfigOptions, + error::DataFusionError, + physical_optimizer::PhysicalOptimizerRule, + physical_plan::{ExecutionPlan, Partitioning}, +}; +use serde::Serialize; + +use crate::test::format_execution_plan; + +#[derive(Debug, Serialize)] +pub struct OptimizationTest { + input: Vec, + output: Result, String>, + + #[serde(skip_serializing)] + output_plan: Option>, +} + +impl OptimizationTest { + pub fn new(input_plan: Arc, opt: O) -> Self + where + O: PhysicalOptimizerRule, + { + Self::new_with_config(input_plan, opt, &ConfigOptions::default()) + } + + pub fn new_with_config( + input_plan: Arc, + opt: O, + config: &ConfigOptions, + ) -> Self + where + O: 
PhysicalOptimizerRule, + { + let input = format_execution_plan(&input_plan); + + let input_schema = input_plan.schema(); + + let output_result = opt.optimize(input_plan, config); + let output_plan = output_result.as_ref().ok().cloned(); + let output = output_result + .and_then(|plan| { + if opt.schema_check() && (plan.schema() != input_schema) { + Err(DataFusionError::External( + format!( + "Schema mismatch:\n\nBefore:\n{:?}\n\nAfter:\n{:?}", + input_schema, + plan.schema() + ) + .into(), + )) + } else { + Ok(plan) + } + }) + .map(|plan| format_execution_plan(&plan)) + .map_err(|e| e.to_string()); + + Self { + input, + output, + output_plan, + } + } + + pub fn output_plan(&self) -> Option<&Arc> { + self.output_plan.as_ref() + } +} + +/// Check if given partitioning is [`Partitioning::UnknownPartitioning`] with the given count. +/// +/// This is needed because [`PartialEq`] for [`Partitioning`] is specified as "unknown != unknown". +#[track_caller] +pub fn assert_unknown_partitioning(partitioning: Partitioning, n: usize) { + match partitioning { + Partitioning::UnknownPartitioning(n2) if n == n2 => {} + _ => panic!( + "Unexpected partitioning, wanted: {:?}, got: {:?}", + Partitioning::UnknownPartitioning(n), + partitioning + ), + } +} diff --git a/iox_query/src/physical_optimizer/tests.rs b/iox_query/src/physical_optimizer/tests.rs new file mode 100644 index 0000000..4e58227 --- /dev/null +++ b/iox_query/src/physical_optimizer/tests.rs @@ -0,0 +1,210 @@ +//! Optimizer edge cases. +//! +//! These are NOT part of the usual end2end query tests because they depend on very specific chunk arrangements that are +//! hard to reproduce in an end2end setting. 
+ +use std::sync::Arc; + +use arrow::datatypes::DataType; +use datafusion::{ + common::DFSchema, + datasource::provider_as_source, + logical_expr::{col, count, lit, Expr, ExprSchemable, LogicalPlanBuilder}, + scalar::ScalarValue, +}; +use schema::sort::SortKey; +use test_helpers::maybe_start_logging; + +use crate::{ + exec::{DedicatedExecutors, Executor, ExecutorConfig, ExecutorType}, + provider::ProviderBuilder, + test::{format_execution_plan, TestChunk}, + QueryChunk, +}; + +/// Test that reconstructs specific case where parquet files may unnecessarily be sorted. +/// +/// See: +/// - +/// - +#[tokio::test] +async fn test_parquet_should_not_be_resorted() { + // DF session setup + let config = ExecutorConfig { + target_query_partitions: 16.try_into().unwrap(), + ..ExecutorConfig::testing() + }; + let exec = Executor::new_with_config_and_executors( + config, + Arc::new(DedicatedExecutors::new_testing()), + ); + let ctx = exec.new_context(ExecutorType::Query); + let state = ctx.inner().state(); + + // chunks + let c = TestChunk::new("t") + .with_tag_column("tag") + .with_time_column_with_full_stats(Some(0), Some(10), 10_000, None); + let c_mem = c.clone().with_may_contain_pk_duplicates(true); + let c_file = c + .clone() + .with_dummy_parquet_file() + .with_may_contain_pk_duplicates(false) + .with_sort_key(SortKey::from_columns([Arc::from("tag"), Arc::from("time")])); + let schema = c.schema().clone(); + let provider = ProviderBuilder::new("t".into(), schema) + .add_chunk(Arc::new(c_mem.clone().with_id(1).with_order(i64::MAX))) + .add_chunk(Arc::new(c_file.clone().with_id(2).with_order(2))) + .add_chunk(Arc::new(c_file.clone().with_id(3).with_order(3))) + .build() + .unwrap(); + + // initial plan + // NOTE: we NEED two time predicates for the bug to trigger! 
+ let expr = col("time") + .gt(lit(ScalarValue::TimestampNanosecond(Some(0), None))) + .and(col("time").gt(lit(ScalarValue::TimestampNanosecond(Some(2), None)))); + + let plan = + LogicalPlanBuilder::scan("t".to_owned(), provider_as_source(Arc::new(provider)), None) + .unwrap() + .filter(expr) + .unwrap() + .aggregate( + std::iter::empty::(), + [count(lit(true)).alias("count")], + ) + .unwrap() + .project([col("count")]) + .unwrap() + .build() + .unwrap(); + + let plan = state.create_physical_plan(&plan).await.unwrap(); + + // The output of the parquet files should not be resorted + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " AggregateExec: mode=Final, gby=[], aggr=[count]" + - " CoalescePartitionsExec" + - " AggregateExec: mode=Partial, gby=[], aggr=[count]" + - " RepartitionExec: partitioning=RoundRobinBatch(16), input_partitions=1" + - " ProjectionExec: expr=[]" + - " DeduplicateExec: [tag@1 ASC,time@2 ASC]" + - " SortPreservingMergeExec: [tag@1 ASC,time@2 ASC,__chunk_order@0 ASC]" + - " UnionExec" + - " SortExec: expr=[tag@1 ASC,time@2 ASC,__chunk_order@0 ASC]" + - " CoalesceBatchesExec: target_batch_size=8192" + - " FilterExec: time@2 > 0 AND time@2 > 2" + - " RepartitionExec: partitioning=RoundRobinBatch(16), input_partitions=1" + - " RecordBatchesExec: chunks=1, projection=[__chunk_order, tag, time]" + - " SortExec: expr=[tag@1 ASC,time@2 ASC,__chunk_order@0 ASC]" + - " CoalesceBatchesExec: target_batch_size=8192" + - " FilterExec: time@2 > 0 AND time@2 > 2" + - " RepartitionExec: partitioning=RoundRobinBatch(16), input_partitions=2" + - " ParquetExec: file_groups={2 groups: [[2.parquet], [3.parquet]]}, projection=[__chunk_order, tag, time], output_ordering=[tag@1 ASC, time@2 ASC, __chunk_order@0 ASC], predicate=time@1 > 0 AND time@1 > 2, pruning_predicate=time_max@0 > 0 AND time_max@0 > 2" + "### + ); +} + +/// Bug reproducer for: +/// - +/// - +#[tokio::test] +async fn test_parquet_must_resorted() { + 
maybe_start_logging(); + + // DF session setup + let config = ExecutorConfig { + target_query_partitions: 6.try_into().unwrap(), + ..ExecutorConfig::testing() + }; + let exec = Executor::new_with_config_and_executors( + config, + Arc::new(DedicatedExecutors::new_testing()), + ); + let ctx = exec.new_context(ExecutorType::Query); + let state = ctx.inner().state(); + + // chunks + let c = TestChunk::new("t") + .with_tag_column("tag") + .with_f64_field_column("field") + .with_time_column_with_full_stats(Some(0), Some(10), 10_000, None) + .with_may_contain_pk_duplicates(false) + .with_sort_key(SortKey::from_columns([Arc::from("tag"), Arc::from("time")])); + let schema = c.schema().clone(); + let df_schema = DFSchema::try_from(schema.as_arrow().as_ref().clone()).unwrap(); + let provider = ProviderBuilder::new("t".into(), schema) + // need a small file followed by a big one + .add_chunk(Arc::new( + c.clone() + .with_id(1) + .with_order(1) + .with_dummy_parquet_file_and_size(1), + )) + .add_chunk(Arc::new( + c.clone() + .with_id(2) + .with_order(2) + .with_dummy_parquet_file_and_size(100_000_000), + )) + .build() + .unwrap(); + + // initial plan + let expr = col("tag") + .gt(lit("foo")) + .and(col("time").gt(lit(ScalarValue::TimestampNanosecond(Some(2), None)))) + .and( + col("field") + .cast_to(&DataType::Utf8, &df_schema) + .unwrap() + .not_eq(lit("")), + ); + + let plan = + LogicalPlanBuilder::scan("t".to_owned(), provider_as_source(Arc::new(provider)), None) + .unwrap() + .filter(expr) + .unwrap() + .project([col("tag")]) + .unwrap() + .build() + .unwrap(); + + let plan = state.create_physical_plan(&plan).await.unwrap(); + + // The output of the parquet files must be sorted prior to merging + // if the first file_group has more than one file + // + // Prior to https://github.com/influxdata/influxdb_iox/issues/9450, the plan + // called for the ParquetExec to read the files in parallel (using subranges) like: + // ``` + // {6 groups: [[1.parquet:0..1, 
2.parquet:0..16666666], [2.parquet:16666666..33333333],... + // ``` + // + // Groups with more than one file produce an output partition that is the + // result of concatenating them together, so even if the output of each + // individual file is sorted, the output of the partition is not, due to the + // concatenation. + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " ProjectionExec: expr=[tag@1 as tag]" + - " CoalesceBatchesExec: target_batch_size=8192" + - " FilterExec: CAST(field@0 AS Utf8) != " + - " RepartitionExec: partitioning=RoundRobinBatch(6), input_partitions=1" + - " ProjectionExec: expr=[field@1 as field, tag@3 as tag]" + - " DeduplicateExec: [tag@3 ASC,time@2 ASC]" + - " SortPreservingMergeExec: [tag@3 ASC,time@2 ASC,__chunk_order@0 ASC]" + - " CoalesceBatchesExec: target_batch_size=8192" + - " FilterExec: tag@3 > foo AND time@2 > 2" + - " RepartitionExec: partitioning=RoundRobinBatch(6), input_partitions=2, preserve_order=true, sort_exprs=tag@3 ASC,time@2 ASC,__chunk_order@0 ASC" + - " ParquetExec: file_groups={2 groups: [[1.parquet], [2.parquet]]}, projection=[__chunk_order, field, time, tag], output_ordering=[tag@3 ASC, time@2 ASC, __chunk_order@0 ASC], predicate=tag@1 > foo AND time@2 > 2, pruning_predicate=tag_max@0 > foo AND time_max@1 > 2" + "### + ); +} diff --git a/iox_query/src/physical_optimizer/union/mod.rs b/iox_query/src/physical_optimizer/union/mod.rs new file mode 100644 index 0000000..df595eb --- /dev/null +++ b/iox_query/src/physical_optimizer/union/mod.rs @@ -0,0 +1,6 @@ +//! Rules specific to [`UnionExec`]. +//! +//! 
[`UnionExec`]: datafusion::physical_plan::union::UnionExec + +pub mod nested_union; +pub mod one_union; diff --git a/iox_query/src/physical_optimizer/union/nested_union.rs b/iox_query/src/physical_optimizer/union/nested_union.rs new file mode 100644 index 0000000..7a05139 --- /dev/null +++ b/iox_query/src/physical_optimizer/union/nested_union.rs @@ -0,0 +1,189 @@ +use std::sync::Arc; + +use datafusion::{ + common::tree_node::{Transformed, TreeNode}, + config::ConfigOptions, + error::Result, + physical_optimizer::PhysicalOptimizerRule, + physical_plan::{union::UnionExec, ExecutionPlan}, +}; + +/// Optimizer that replaces nested [`UnionExec`]s with a single level. +/// +/// # Example +/// ```yaml +/// --- +/// UnionExec: +/// - UnionExec: +/// - SomeExec1 +/// - SomeExec2 +/// - SomeExec3 +/// +/// --- +/// UnionExec: +/// - SomeExec1 +/// - SomeExec2 +/// - SomeExec3 +/// ``` +#[derive(Debug, Default)] +pub struct NestedUnion; + +impl PhysicalOptimizerRule for NestedUnion { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + plan.transform_up(&|plan| { + let plan_any = plan.as_any(); + + if let Some(union_exec) = plan_any.downcast_ref::() { + let children = union_exec.children(); + + let mut children_new = Vec::with_capacity(children.len()); + let mut found_union = false; + for child in children { + if let Some(union_child) = child.as_any().downcast_ref::() { + found_union = true; + children_new.append(&mut union_child.children()); + } else { + children_new.push(child) + } + } + + if found_union { + return Ok(Transformed::Yes(Arc::new(UnionExec::new(children_new)))); + } + } + + Ok(Transformed::No(plan)) + }) + } + + fn name(&self) -> &str { + "nested_union" + } + + fn schema_check(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion::physical_plan::empty::EmptyExec; + + use crate::physical_optimizer::test_util::OptimizationTest; + + use super::*; + + 
#[test] + #[should_panic(expected = "index out of bounds")] + fn test_union_empty() { + // empty UnionExecs cannot be created in the first place + UnionExec::new(vec![]); + } + + #[test] + fn test_union_not_nested() { + let plan = Arc::new(UnionExec::new(vec![other_node()])); + let opt = NestedUnion; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " UnionExec" + - " EmptyExec" + output: + Ok: + - " UnionExec" + - " EmptyExec" + "### + ); + } + + #[test] + fn test_union_nested() { + let plan = Arc::new(UnionExec::new(vec![ + Arc::new(UnionExec::new(vec![other_node(), other_node()])), + other_node(), + ])); + let opt = NestedUnion; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " UnionExec" + - " UnionExec" + - " EmptyExec" + - " EmptyExec" + - " EmptyExec" + output: + Ok: + - " UnionExec" + - " EmptyExec" + - " EmptyExec" + - " EmptyExec" + "### + ); + } + + #[test] + fn test_union_deeply_nested() { + let plan = Arc::new(UnionExec::new(vec![ + Arc::new(UnionExec::new(vec![ + other_node(), + Arc::new(UnionExec::new(vec![other_node()])), + ])), + other_node(), + ])); + let opt = NestedUnion; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " UnionExec" + - " UnionExec" + - " EmptyExec" + - " UnionExec" + - " EmptyExec" + - " EmptyExec" + output: + Ok: + - " UnionExec" + - " EmptyExec" + - " EmptyExec" + - " EmptyExec" + "### + ); + } + + #[test] + fn test_other_node() { + let plan = other_node(); + let opt = NestedUnion; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " EmptyExec" + output: + Ok: + - " EmptyExec" + "### + ); + } + + fn other_node() -> Arc { + Arc::new(EmptyExec::new(schema())) + } + + fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("c", DataType::UInt32, false)])) + } +} diff --git a/iox_query/src/physical_optimizer/union/one_union.rs 
b/iox_query/src/physical_optimizer/union/one_union.rs new file mode 100644 index 0000000..15f277a --- /dev/null +++ b/iox_query/src/physical_optimizer/union/one_union.rs @@ -0,0 +1,133 @@ +use std::sync::Arc; + +use datafusion::{ + common::tree_node::{Transformed, TreeNode}, + config::ConfigOptions, + error::Result, + physical_optimizer::PhysicalOptimizerRule, + physical_plan::{union::UnionExec, ExecutionPlan}, +}; + +/// Optimizer that replaces [`UnionExec`] with a single child node w/ the child note itself. +/// +/// # Example +/// ```yaml +/// --- +/// UnionExec: +/// - SomeExec1 +/// +/// --- +/// SomeExec1 +/// ``` +#[derive(Debug, Default)] +pub struct OneUnion; + +impl PhysicalOptimizerRule for OneUnion { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + plan.transform_up(&|plan| { + let plan_any = plan.as_any(); + + if let Some(union_exec) = plan_any.downcast_ref::() { + let mut children = union_exec.children(); + if children.len() == 1 { + return Ok(Transformed::Yes(children.remove(0))); + } + } + + Ok(Transformed::No(plan)) + }) + } + + fn name(&self) -> &str { + "one_union" + } + + fn schema_check(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion::physical_plan::empty::EmptyExec; + + use crate::physical_optimizer::test_util::OptimizationTest; + + use super::*; + + #[test] + #[should_panic(expected = "index out of bounds")] + fn test_union_empty() { + // empty UnionExecs cannot be created in the first place + UnionExec::new(vec![]); + } + + #[test] + fn test_union_one() { + let plan = Arc::new(UnionExec::new(vec![other_node()])); + let opt = OneUnion; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " UnionExec" + - " EmptyExec" + output: + Ok: + - " EmptyExec" + "### + ); + } + + #[test] + fn test_union_two() { + let plan = Arc::new(UnionExec::new(vec![other_node(), other_node()])); + let 
opt = OneUnion; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " UnionExec" + - " EmptyExec" + - " EmptyExec" + output: + Ok: + - " UnionExec" + - " EmptyExec" + - " EmptyExec" + "### + ); + } + + #[test] + fn test_other_node() { + let plan = other_node(); + let opt = OneUnion; + insta::assert_yaml_snapshot!( + OptimizationTest::new(plan, opt), + @r###" + --- + input: + - " EmptyExec" + output: + Ok: + - " EmptyExec" + "### + ); + } + + fn other_node() -> Arc { + Arc::new(EmptyExec::new(schema())) + } + + fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("c", DataType::UInt32, false)])) + } +} diff --git a/iox_query/src/plan.rs b/iox_query/src/plan.rs new file mode 100644 index 0000000..693ff90 --- /dev/null +++ b/iox_query/src/plan.rs @@ -0,0 +1,3 @@ +pub mod fieldlist; +pub mod seriesset; +pub mod stringset; diff --git a/iox_query/src/plan/fieldlist.rs b/iox_query/src/plan/fieldlist.rs new file mode 100644 index 0000000..e2e19f7 --- /dev/null +++ b/iox_query/src/plan/fieldlist.rs @@ -0,0 +1,57 @@ +use datafusion::logical_expr::LogicalPlan; + +use crate::exec::fieldlist::Field; +use std::collections::BTreeMap; + +pub type FieldSet = BTreeMap; + +/// A plan which produces a logical set of Fields (e.g. InfluxDB +/// Fields with name, and data type, and last_timestamp). +/// +/// known_values has a set of pre-computed values to be merged with +/// the extra_plans. +#[derive(Debug, Default)] +pub struct FieldListPlan { + /// Known values + pub known_values: FieldSet, + /// General plans + pub extra_plans: Vec, +} + +impl From> for FieldListPlan { + /// Create FieldList plan from a DataFusion LogicalPlan node, each + /// of which must produce fields in the correct format. The output + /// of each plan will be included into the final set. 
+ fn from(plans: Vec) -> Self { + Self { + known_values: FieldSet::new(), + extra_plans: plans, + } + } +} + +impl From for FieldListPlan { + /// Create a StringSet plan from a single DataFusion LogicalPlan + /// node, which must produce fields in the correct format + fn from(plan: LogicalPlan) -> Self { + Self::from(vec![plan]) + } +} + +impl FieldListPlan { + pub fn new() -> Self { + Self::default() + } + + /// Append the other plan to ourselves + pub fn append_other(mut self, other: Self) -> Self { + self.extra_plans.extend(other.extra_plans); + self.known_values.extend(other.known_values); + self + } + + /// Append a single field to the known set of fields in this builder + pub fn append_field(&mut self, s: Field) { + self.known_values.insert(s.name.clone(), s); + } +} diff --git a/iox_query/src/plan/seriesset.rs b/iox_query/src/plan/seriesset.rs new file mode 100644 index 0000000..a158438 --- /dev/null +++ b/iox_query/src/plan/seriesset.rs @@ -0,0 +1,108 @@ +use std::sync::Arc; + +use datafusion::logical_expr::LogicalPlan; + +use crate::exec::field::FieldColumns; + +/// A plan that can be run to produce a logical stream of time series, +/// as represented as sequence of SeriesSets from a single DataFusion +/// plan, optionally grouped in some way. +/// +/// TODO: remove the tag/field designations below and attach a +/// `Schema` to the plan (which has the tag and field column +/// information natively) +#[derive(Debug)] +pub struct SeriesSetPlan { + /// The table name this came from + pub table_name: Arc, + + /// Datafusion plan to execute. The plan must produce + /// RecordBatches that have: + /// + /// * fields for each name in `tag_columns` and `field_columns` + /// * a timestamp column called 'time' + /// * each column in tag_columns must be a String (Utf8) + pub plan: LogicalPlan, + + /// The names of the columns that define tags. 
+ /// + /// Note these are `Arc` strings because they are duplicated for + /// *each* resulting `SeriesSet` that is produced when this type + /// of plan is executed. + pub tag_columns: Vec>, + + /// The names of the columns which are "fields" + pub field_columns: FieldColumns, +} + +impl SeriesSetPlan { + /// Create a SeriesSetPlan that will not produce any Group items + pub fn new_from_shared_timestamp( + table_name: Arc, + plan: LogicalPlan, + tag_columns: Vec>, + field_columns: Vec>, + ) -> Self { + Self::new(table_name, plan, tag_columns, field_columns.into()) + } + + /// Create a SeriesSetPlan that will not produce any Group items + pub fn new( + table_name: Arc, + plan: LogicalPlan, + tag_columns: Vec>, + field_columns: FieldColumns, + ) -> Self { + Self { + table_name, + plan, + tag_columns, + field_columns, + } + } +} + +/// A container for plans which each produce a logical stream of +/// timeseries (from across many potential tables). +#[derive(Debug, Default)] +pub struct SeriesSetPlans { + /// Plans the generate Series, ordered by table_name. + /// + /// Each plan produces output that is sorted by tag keys (tag + /// column values) and then time. + pub plans: Vec, + + /// grouping keys, if any, that specify how the output series should be + /// sorted (aka grouped). If empty, means no grouping is needed + /// + /// There are several special values that are possible in `group_keys`: + /// + /// 1. _field (means group by field column name) + /// 2. _measurement (means group by the table name) + /// 3. 
_time (means group by the time column) + pub group_columns: Option>>, +} + +impl SeriesSetPlans { + pub fn into_inner(self) -> Vec { + self.plans + } +} + +impl SeriesSetPlans { + /// Create a new, ungrouped SeriesSetPlans + pub fn new(plans: Vec) -> Self { + Self { + plans, + group_columns: None, + } + } + + /// Group the created SeriesSetPlans + pub fn grouped_by(self, group_columns: Vec>) -> Self { + Self { + group_columns: Some(group_columns), + ..self + } + } +} diff --git a/iox_query/src/plan/stringset.rs b/iox_query/src/plan/stringset.rs new file mode 100644 index 0000000..8e49d2c --- /dev/null +++ b/iox_query/src/plan/stringset.rs @@ -0,0 +1,233 @@ +use std::sync::Arc; + +use arrow_util::util::str_iter_to_batch; +use datafusion::logical_expr::LogicalPlan; + +/// The name of the column containing table names returned by a call to +/// `table_names`. +const TABLE_NAMES_COLUMN_NAME: &str = "table"; + +use crate::{ + exec::stringset::{StringSet, StringSetRef}, + util::make_scan_plan, +}; + +use snafu::{ResultExt, Snafu}; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("Internal error converting to arrow: {}", source))] + InternalConvertingToArrow { source: arrow::error::ArrowError }, + + #[snafu(display("Internal error creating a plan for stringset: {}", source))] + InternalPlanningStringSet { + source: datafusion::error::DataFusionError, + }, +} + +pub type Result = std::result::Result; + +/// A plan which produces a logical set of Strings (e.g. tag +/// values). This includes variants with pre-calculated results as +/// well a variant that runs a full on DataFusion plan. +#[derive(Debug)] +pub enum StringSetPlan { + /// The results are known from metadata only without having to run + /// an actual datafusion plan + Known(StringSetRef), + + /// A DataFusion plan(s) to execute. 
Each plan must produce + /// RecordBatches with exactly one String column, though the + /// values produced by the plan may be repeated + /// + /// TODO: it would be cool to have a single datafusion LogicalPlan + /// that merged all the results together. However, no such Union + /// node exists at the time of writing, so we do the unioning in IOx + Plan(Vec), +} + +impl From for StringSetPlan { + /// Create a StringSetPlan from a StringSetRef + fn from(set: StringSetRef) -> Self { + Self::Known(set) + } +} + +impl From for StringSetPlan { + /// Create a StringSetPlan from a StringSet result, wrapping the error type + /// appropriately + fn from(set: StringSet) -> Self { + Self::Known(StringSetRef::new(set)) + } +} + +impl From> for StringSetPlan { + /// Create StringSet plan from a DataFusion LogicalPlan node, each + /// of which must produce a single output Utf8 column. The output + /// of each plan will be included into the final set. + fn from(plans: Vec) -> Self { + Self::Plan(plans) + } +} + +impl From for StringSetPlan { + /// Create a StringSet plan from a single DataFusion LogicalPlan + /// node which produces a single output Utf8 column. 
+ fn from(plan: LogicalPlan) -> Self { + Self::Plan(vec![plan]) + } +} + +/// Builder for StringSet plans for appending multiple plans together +/// +/// If the values are known beforehand, the builder merges the +/// strings, otherwise it falls back to generic plans +#[derive(Debug, Default)] +pub struct StringSetPlanBuilder { + /// Known strings + strings: StringSet, + /// General plans + plans: Vec, +} + +impl StringSetPlanBuilder { + pub fn new() -> Self { + Self::default() + } + + /// Append the strings from the passed plan into ourselves if possible, or + /// passes on the plan + pub fn append_other(mut self, other: StringSetPlan) -> Self { + match other { + StringSetPlan::Known(ssref) => match Arc::try_unwrap(ssref) { + Ok(mut ss) => { + self.strings.append(&mut ss); + } + Err(ssref) => { + for s in &*ssref { + if !self.strings.contains(s) { + self.strings.insert(s.clone()); + } + } + } + }, + StringSetPlan::Plan(mut other_plans) => self.plans.append(&mut other_plans), + } + + self + } + + /// Return true if we know already that `s` is contained in the + /// StringSet. Note that if `contains()` returns false, `s` may be + /// in the stringset after execution. + pub fn contains(&self, s: impl AsRef) -> bool { + self.strings.contains(s.as_ref()) + } + + /// Append a single string to the known set of strings in this builder + pub fn append_string(&mut self, s: impl Into) { + self.strings.insert(s.into()); + } + + /// returns an iterator over the currently known strings in this builder + pub fn known_strings_iter(&self) -> impl Iterator { + self.strings.iter() + } + + /// Create a StringSetPlan that produces the deduplicated (union) + /// of all plans `append`ed to this builder. 
+ pub fn build(self) -> Result { + let Self { strings, mut plans } = self; + + if plans.is_empty() { + // only a known set of strings + Ok(StringSetPlan::Known(Arc::new(strings))) + } else { + // Had at least one general plan, so need to use general + // purpose plan for the known strings + if !strings.is_empty() { + let batch = + str_iter_to_batch(TABLE_NAMES_COLUMN_NAME, strings.into_iter().map(Some)) + .context(InternalConvertingToArrowSnafu)?; + + let plan = make_scan_plan(batch).context(InternalPlanningStringSetSnafu)?; + + plans.push(plan) + } + + Ok(StringSetPlan::Plan(plans)) + } + } +} + +#[cfg(test)] +mod tests { + use crate::exec::{Executor, ExecutorType}; + + use super::*; + + #[test] + fn test_builder_empty() { + let plan = StringSetPlanBuilder::new().build().unwrap(); + let empty_ss = StringSet::new().into(); + if let StringSetPlan::Known(ss) = plan { + assert_eq!(ss, empty_ss) + } else { + panic!("unexpected type: {plan:?}") + } + } + + #[test] + fn test_builder_strings_only() { + let plan = StringSetPlanBuilder::new() + .append_other(to_string_set(&["foo", "bar"]).into()) + .append_other(to_string_set(&["bar", "baz"]).into()) + .build() + .unwrap(); + + let expected_ss = to_string_set(&["foo", "bar", "baz"]).into(); + + if let StringSetPlan::Known(ss) = plan { + assert_eq!(ss, expected_ss) + } else { + panic!("unexpected type: {plan:?}") + } + } + + #[derive(Debug)] + struct TestError {} + + impl std::fmt::Display for TestError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "this is an error") + } + } + + impl std::error::Error for TestError {} + + #[tokio::test] + async fn test_builder_plan() { + let batch = str_iter_to_batch("column_name", vec![Some("from_a_plan")]).unwrap(); + let df_plan = make_scan_plan(batch).unwrap(); + + // when a df plan is appended the whole plan should be different + let plan = StringSetPlanBuilder::new() + .append_other(to_string_set(&["foo", "bar"]).into()) + 
.append_other(vec![df_plan].into()) + .append_other(to_string_set(&["baz"]).into()) + .build() + .unwrap(); + + let expected_ss = to_string_set(&["foo", "bar", "baz", "from_a_plan"]).into(); + + assert!(matches!(plan, StringSetPlan::Plan(_))); + let exec = Executor::new_testing(); + let ctx = exec.new_context(ExecutorType::Query); + let ss = ctx.to_string_set(plan).await.unwrap(); + assert_eq!(ss, expected_ss); + } + + fn to_string_set(v: &[&str]) -> StringSet { + v.iter().map(|s| s.to_string()).collect::() + } +} diff --git a/iox_query/src/provider.rs b/iox_query/src/provider.rs new file mode 100644 index 0000000..3fab97e --- /dev/null +++ b/iox_query/src/provider.rs @@ -0,0 +1,641 @@ +//! Implementation of a DataFusion `TableProvider` in terms of `QueryChunk`s + +use async_trait::async_trait; +use std::{collections::HashSet, sync::Arc}; + +use arrow::{ + datatypes::{Fields, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}, + error::ArrowError, +}; +use datafusion::{ + datasource::{provider_as_source, TableProvider}, + error::{DataFusionError, Result as DataFusionResult}, + execution::context::SessionState, + logical_expr::{ + utils::{conjunction, split_conjunction}, + LogicalPlanBuilder, TableProviderFilterPushDown, TableType, + }, + physical_plan::{ + expressions::col as physical_col, filter::FilterExec, projection::ProjectionExec, + ExecutionPlan, + }, + prelude::Expr, + sql::TableReference, +}; +use observability_deps::tracing::trace; +use schema::{sort::SortKey, Schema}; + +use crate::{ + chunk_order_field, + util::{arrow_sort_key_exprs, df_physical_expr}, + QueryChunk, CHUNK_ORDER_COLUMN_NAME, +}; + +use snafu::{ResultExt, Snafu}; + +mod adapter; +mod deduplicate; +pub mod overlap; +mod physical; +pub(crate) mod progressive_eval; +mod record_batch_exec; +pub use self::overlap::group_potential_duplicates; +pub use deduplicate::{DeduplicateExec, RecordBatchDeduplicator}; +pub(crate) use physical::{chunks_to_physical_nodes, PartitionedFileExt}; + +pub(crate) 
use record_batch_exec::RecordBatchesExec; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display( + "Internal error: no chunk pruner provided to builder for {}", + table_name, + ))] + InternalNoChunkPruner { table_name: String }, + + #[snafu(display("Internal error: Cannot create projection select expr '{}'", source,))] + InternalSelectExpr { + source: datafusion::error::DataFusionError, + }, + + #[snafu(display("Internal error adding sort operator '{}'", source,))] + InternalSort { + source: datafusion::error::DataFusionError, + }, + + #[snafu(display("Internal error adding filter operator '{}'", source,))] + InternalFilter { + source: datafusion::error::DataFusionError, + }, + + #[snafu(display("Internal error adding projection operator '{}'", source,))] + InternalProjection { + source: datafusion::error::DataFusionError, + }, +} +pub type Result = std::result::Result; + +impl From for ArrowError { + // Wrap an error into an arrow error + fn from(e: Error) -> Self { + Self::ExternalError(Box::new(e)) + } +} + +impl From for DataFusionError { + // Wrap an error into a datafusion error + fn from(e: Error) -> Self { + Self::ArrowError(e.into(), None) + } +} + +/// Builds a `ChunkTableProvider` from a series of `QueryChunk`s +/// and ensures the schema across the chunks is compatible and +/// consistent. 
#[derive(Debug)]
pub struct ProviderBuilder {
    table_name: Arc<str>,
    schema: Schema,
    chunks: Vec<Arc<dyn QueryChunk>>,
    deduplication: bool,
}

impl ProviderBuilder {
    /// Create a builder for `table_name` with the given IOx schema.
    ///
    /// Panics if the schema already contains the internal chunk-order
    /// column, which the provider adds itself during `scan`.
    pub fn new(table_name: Arc<str>, schema: Schema) -> Self {
        assert_eq!(schema.find_index_of(CHUNK_ORDER_COLUMN_NAME), None);

        Self {
            table_name,
            schema,
            chunks: Vec::new(),
            // De-duplication is on by default; see `with_enable_deduplication`.
            deduplication: true,
        }
    }

    /// Enable or disable de-duplication of primary-key duplicates during scans.
    pub fn with_enable_deduplication(mut self, enable_deduplication: bool) -> Self {
        self.deduplication = enable_deduplication;
        self
    }

    /// Add a new chunk to this provider
    pub fn add_chunk(mut self, chunk: Arc<dyn QueryChunk>) -> Self {
        self.chunks.push(chunk);
        self
    }

    /// Create the Provider
    pub fn build(self) -> Result<ChunkTableProvider> {
        Ok(ChunkTableProvider {
            iox_schema: self.schema,
            table_name: self.table_name,
            chunks: self.chunks,
            deduplication: self.deduplication,
        })
    }
}

/// Implementation of a DataFusion TableProvider in terms of QueryChunks
///
/// This allows DataFusion to see data from Chunks as a single table, as well as
/// push predicates and selections down to chunks
#[derive(Debug)]
pub struct ChunkTableProvider {
    table_name: Arc<str>,
    /// The IOx schema (wrapper around Arrow Schemaref) for this table
    iox_schema: Schema,
    /// The chunks
    chunks: Vec<Arc<dyn QueryChunk>>,
    /// do deduplication
    deduplication: bool,
}

impl ChunkTableProvider {
    /// Return the IOx schema view for the data provided by this provider
    pub fn iox_schema(&self) -> &Schema {
        &self.iox_schema
    }

    /// Return the Arrow schema view for the data provided by this provider
    pub fn arrow_schema(&self) -> ArrowSchemaRef {
        self.iox_schema.as_arrow()
    }

    /// Return the table name
    pub fn table_name(&self) -> &str {
        self.table_name.as_ref()
    }

    /// Running deduplication or not
    pub fn deduplication(&self) -> bool {
        self.deduplication
    }

    /// Convert into a logical plan builder.
+ pub fn into_logical_plan_builder( + self: Arc, + ) -> Result { + let table_name = self.table_name().to_owned(); + let source = provider_as_source(self as _); + + // Scan all columns (DataFusion optimizer will prune this + // later if possible) + let projection = None; + + // Do not parse the tablename as a SQL identifer, but use as is + let table_ref = TableReference::bare(table_name); + LogicalPlanBuilder::scan(table_ref, source, projection) + } +} + +#[async_trait] +impl TableProvider for ChunkTableProvider { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + /// Schema with all available columns across all chunks + fn schema(&self) -> ArrowSchemaRef { + self.arrow_schema() + } + + /// Creates a plan like the following: + /// + /// ```text + /// Project (keep only columns needed in the rest of the plan) + /// Filter (optional, apply any push down predicates) + /// Deduplicate (optional, if chunks overlap) + /// ... Scan of Chunks (RecordBatchExec / ParquetExec / UnionExec, etc) ... + /// ``` + async fn scan( + &self, + ctx: &SessionState, + projection: Option<&Vec>, + filters: &[Expr], + _limit: Option, + ) -> std::result::Result, DataFusionError> { + trace!("Create a scan node for ChunkTableProvider"); + + let schema_with_chunk_order = Arc::new(ArrowSchema::new( + self.iox_schema + .as_arrow() + .fields + .iter() + .cloned() + .chain(std::iter::once(chunk_order_field())) + .collect::(), + )); + let pk = self.iox_schema().primary_key(); + let dedup_sort_key = SortKey::from_columns(pk.iter().copied()); + + // Create data stream from chunk data. This is the most simple data stream possible and contains duplicates and + // has no filters at all. + let plan = chunks_to_physical_nodes( + &schema_with_chunk_order, + None, + self.chunks.clone(), + ctx.config().target_partitions(), + ); + + // De-dup before doing anything else, because all logical expressions act on de-duplicated data. 
+ let plan = if self.deduplication { + let sort_exprs = arrow_sort_key_exprs(&dedup_sort_key, &plan.schema()); + Arc::new(DeduplicateExec::new(plan, sort_exprs, true)) + } else { + plan + }; + + // Filter as early as possible (AFTER de-dup!). Predicate pushdown will eventually push down parts of this. + let plan = if let Some(expr) = filters.iter().cloned().reduce(|a, b| a.and(b)) { + let maybe_expr = if !self.deduplication { + let dedup_cols = pk.into_iter().collect::>(); + conjunction( + split_conjunction(&expr) + .into_iter() + .filter(|expr| { + let Ok(expr_cols) = expr.to_columns() else { + return false; + }; + expr_cols + .into_iter() + .all(|c| dedup_cols.contains(c.name.as_str())) + }) + .cloned(), + ) + } else { + Some(expr) + }; + + if let Some(expr) = maybe_expr { + Arc::new(FilterExec::try_new( + df_physical_expr(plan.schema(), expr)?, + plan, + )?) + } else { + plan + } + } else { + plan + }; + + // Project at last because it removes columns and hence other operations may fail. Projection pushdown will + // optimize that later. + // Always project because we MUST make sure that chunk order col doesn't leak to the user or to our parquet + // files. 
+ let default_projection: Vec<_> = (0..self.iox_schema.len()).collect(); + let projection = projection.unwrap_or(&default_projection); + let select_exprs = self + .iox_schema() + .select_by_indices(projection) + .as_arrow() + .fields() + .iter() + .map(|f| { + let field_name = f.name(); + let physical_expr = + physical_col(field_name, &self.schema()).context(InternalSelectExprSnafu)?; + Ok((physical_expr, field_name.to_string())) + }) + .collect::>>()?; + + let plan = Arc::new(ProjectionExec::try_new(select_exprs, plan)?); + + Ok(plan) + } + + /// Filter pushdown specification + fn supports_filter_pushdown( + &self, + _filter: &Expr, + ) -> DataFusionResult { + if self.deduplication { + Ok(TableProviderFilterPushDown::Exact) + } else { + Ok(TableProviderFilterPushDown::Inexact) + } + } + + fn table_type(&self) -> TableType { + TableType::Base + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{ + exec::IOxSessionContext, + pruning::retention_expr, + test::{format_execution_plan, TestChunk}, + }; + use datafusion::prelude::{col, lit}; + + #[tokio::test] + async fn provider_scan_default() { + let table_name = "t"; + let chunk1 = Arc::new( + TestChunk::new(table_name) + .with_id(1) + .with_tag_column("tag1") + .with_tag_column("tag2") + .with_f64_field_column("field") + .with_time_column(), + ) as Arc; + let chunk2 = Arc::new( + TestChunk::new(table_name) + .with_id(2) + .with_dummy_parquet_file() + .with_tag_column("tag1") + .with_tag_column("tag2") + .with_f64_field_column("field") + .with_time_column(), + ) as Arc; + let schema = chunk1.schema().clone(); + + let ctx = IOxSessionContext::with_testing(); + let state = ctx.inner().state(); + + let provider = ProviderBuilder::new(Arc::from(table_name), schema) + .add_chunk(Arc::clone(&chunk1)) + .add_chunk(Arc::clone(&chunk2)) + .build() + .unwrap(); + + // simple plan + let plan = provider.scan(&state, None, &[], None).await.unwrap(); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + 
@r###" + --- + - " ProjectionExec: expr=[field@0 as field, tag1@1 as tag1, tag2@2 as tag2, time@3 as time]" + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[2.parquet]]}, projection=[field, tag1, tag2, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + + // projection + let plan = provider + .scan(&state, Some(&vec![1, 3]), &[], None) + .await + .unwrap(); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " ProjectionExec: expr=[tag1@1 as tag1, time@3 as time]" + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[2.parquet]]}, projection=[field, tag1, tag2, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + + // filters + let expr = vec![lit(false)]; + let expr_ref = expr.iter().collect::>(); + assert_eq!( + provider.supports_filters_pushdown(&expr_ref).unwrap(), + vec![TableProviderFilterPushDown::Exact] + ); + let plan = provider.scan(&state, None, &expr, None).await.unwrap(); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " ProjectionExec: expr=[field@0 as field, tag1@1 as tag1, tag2@2 as tag2, time@3 as time]" + - " FilterExec: false" + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[2.parquet]]}, projection=[field, tag1, tag2, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + + // limit pushdown is unimplemented at the moment + let plan = provider.scan(&state, None, &[], Some(1)).await.unwrap(); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + 
@r###" + --- + - " ProjectionExec: expr=[field@0 as field, tag1@1 as tag1, tag2@2 as tag2, time@3 as time]" + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[2.parquet]]}, projection=[field, tag1, tag2, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + } + + #[tokio::test] + async fn provider_scan_no_dedup() { + let table_name = "t"; + let chunk1 = Arc::new( + TestChunk::new(table_name) + .with_id(1) + .with_tag_column("tag1") + .with_tag_column("tag2") + .with_f64_field_column("field") + .with_time_column(), + ) as Arc; + let chunk2 = Arc::new( + TestChunk::new(table_name) + .with_id(2) + .with_dummy_parquet_file() + .with_tag_column("tag1") + .with_tag_column("tag2") + .with_f64_field_column("field") + .with_time_column(), + ) as Arc; + let schema = chunk1.schema().clone(); + + let ctx = IOxSessionContext::with_testing(); + let state = ctx.inner().state(); + + let provider = ProviderBuilder::new(Arc::from(table_name), schema) + .add_chunk(Arc::clone(&chunk1)) + .add_chunk(Arc::clone(&chunk2)) + .with_enable_deduplication(false) + .build() + .unwrap(); + + // simple plan + let plan = provider.scan(&state, None, &[], None).await.unwrap(); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " ProjectionExec: expr=[field@0 as field, tag1@1 as tag1, tag2@2 as tag2, time@3 as time]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[2.parquet]]}, projection=[field, tag1, tag2, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + + // projection + let plan = provider + .scan(&state, Some(&vec![1, 3]), &[], None) + .await + .unwrap(); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " ProjectionExec: 
expr=[tag1@1 as tag1, time@3 as time]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[2.parquet]]}, projection=[field, tag1, tag2, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + + // filters + // Expressions on fields are NOT pushed down because they cannot be pushed through de-dup. + let expr = vec![ + lit(false), + col("tag1").eq(lit("foo")), + col("field").eq(lit(1.0)), + ]; + let expr_ref = expr.iter().collect::>(); + assert_eq!( + provider.supports_filters_pushdown(&expr_ref).unwrap(), + vec![ + TableProviderFilterPushDown::Inexact, + TableProviderFilterPushDown::Inexact, + TableProviderFilterPushDown::Inexact + ] + ); + let plan = provider.scan(&state, None, &expr, None).await.unwrap(); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " ProjectionExec: expr=[field@0 as field, tag1@1 as tag1, tag2@2 as tag2, time@3 as time]" + - " FilterExec: false AND tag1@1 = CAST(foo AS Dictionary(Int32, Utf8))" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[2.parquet]]}, projection=[field, tag1, tag2, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + + // limit pushdown is unimplemented at the moment + let plan = provider.scan(&state, None, &[], Some(1)).await.unwrap(); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " ProjectionExec: expr=[field@0 as field, tag1@1 as tag1, tag2@2 as tag2, time@3 as time]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[2.parquet]]}, projection=[field, tag1, tag2, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + } + + #[tokio::test] + async fn provider_scan_retention() { + let table_name = "t"; + 
let pred = retention_expr(100); + let chunk1 = Arc::new( + TestChunk::new(table_name) + .with_id(1) + .with_tag_column("tag1") + .with_tag_column("tag2") + .with_f64_field_column("field") + .with_time_column(), + ) as Arc; + let chunk2 = Arc::new( + TestChunk::new(table_name) + .with_id(2) + .with_dummy_parquet_file() + .with_tag_column("tag1") + .with_tag_column("tag2") + .with_f64_field_column("field") + .with_time_column(), + ) as Arc; + let schema = chunk1.schema().clone(); + + let ctx = IOxSessionContext::with_testing(); + let state = ctx.inner().state(); + + let provider = ProviderBuilder::new(Arc::from(table_name), schema) + .add_chunk(Arc::clone(&chunk1)) + .add_chunk(Arc::clone(&chunk2)) + .build() + .unwrap(); + + // simple plan + let plan = provider + .scan(&state, None, &[pred.clone()], None) + .await + .unwrap(); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " ProjectionExec: expr=[field@0 as field, tag1@1 as tag1, tag2@2 as tag2, time@3 as time]" + - " FilterExec: time@3 > 100" + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[2.parquet]]}, projection=[field, tag1, tag2, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + + // projection + let plan = provider + .scan(&state, Some(&vec![1, 3]), &[pred.clone()], None) + .await + .unwrap(); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " ProjectionExec: expr=[tag1@1 as tag1, time@3 as time]" + - " FilterExec: time@3 > 100" + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[2.parquet]]}, projection=[field, tag1, tag2, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + + // 
filters + let expr = vec![lit(false), pred.clone()]; + let expr_ref = expr.iter().collect::>(); + assert_eq!( + provider.supports_filters_pushdown(&expr_ref).unwrap(), + vec![ + TableProviderFilterPushDown::Exact, + TableProviderFilterPushDown::Exact + ] + ); + let plan = provider.scan(&state, None, &expr, None).await.unwrap(); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " ProjectionExec: expr=[field@0 as field, tag1@1 as tag1, tag2@2 as tag2, time@3 as time]" + - " FilterExec: false AND time@3 > 100" + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[2.parquet]]}, projection=[field, tag1, tag2, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + + // limit pushdown is unimplemented at the moment + let plan = provider.scan(&state, None, &[pred], Some(1)).await.unwrap(); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " ProjectionExec: expr=[field@0 as field, tag1@1 as tag1, tag2@2 as tag2, time@3 as time]" + - " FilterExec: time@3 > 100" + - " DeduplicateExec: [tag1@1 ASC,tag2@2 ASC,time@3 ASC]" + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[field, tag1, tag2, time, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[2.parquet]]}, projection=[field, tag1, tag2, time, __chunk_order], output_ordering=[__chunk_order@4 ASC]" + "### + ); + } +} diff --git a/iox_query/src/provider/adapter.rs b/iox_query/src/provider/adapter.rs new file mode 100644 index 0000000..a0f1ad9 --- /dev/null +++ b/iox_query/src/provider/adapter.rs @@ -0,0 +1,514 @@ +//! 
Holds a stream that ensures chunks have the same (uniform) schema +use std::{collections::HashMap, sync::Arc}; + +use snafu::Snafu; +use std::task::{Context, Poll}; + +use arrow::{ + array::new_null_array, + datatypes::{DataType, SchemaRef}, + record_batch::RecordBatch, +}; +use datafusion::physical_plan::{ + metrics::BaselineMetrics, RecordBatchStream, SendableRecordBatchStream, +}; +use datafusion::{error::DataFusionError, scalar::ScalarValue}; +use futures::Stream; + +/// Schema creation / validation errors. +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("Internal error creating SchemaAdapterStream: input field '{}' had type '{:?}' which is different than output field '{}' which had type '{:?}'", + input_field_name, input_field_type, output_field_name, output_field_type,))] + InternalDataTypeMismatch { + input_field_name: String, + input_field_type: DataType, + output_field_name: String, + output_field_type: DataType, + }, + + #[snafu(display("Internal error creating SchemaAdapterStream: creating virtual value of type '{:?}' which is different than output field '{}' which had type '{:?}'", + field_type, output_field_name, output_field_type,))] + InternalDataTypeMismatchForVirtual { + field_type: DataType, + output_field_name: String, + output_field_type: DataType, + }, + + #[snafu(display("Internal error creating SchemaAdapterStream: the field '{}' is specified within the input and as a virtual column, don't know which one to choose", + field_name))] + InternalColumnBothInInputAndVirtual { field_name: String }, + + #[snafu(display("Internal error creating SchemaAdapterStream: field '{}' had output type '{:?}' and should be a NULL column but the field is flagged as 'not null'", + field_name, output_field_type,))] + InternalColumnNotNullable { + field_name: String, + output_field_type: DataType, + }, +} +pub type Result = std::result::Result; + +/// This stream wraps another underlying stream to ensure it 
produces +/// the specified schema. If the underlying stream produces a subset +/// of the columns specified in desired schema, this stream creates +/// arrays with NULLs to pad out the missing columns or creates "virtual" columns which contain a fixed given value. +/// +/// For example: +/// +/// If a table had schema with Cols A, B, C, and D, but the chunk (input) +/// stream only produced record batches with columns A and C. For D we provided a virtual value of "foo". This +/// stream would append a column of B / nulls to each record batch +/// that flowed through it and create a constant column D. +/// +/// ```text +/// +/// ┌────────────────┐ ┌───────────────────────────────┐ +/// │ ┌─────┐┌─────┐ │ │ ┌─────┐┌──────┐┌─────┐┌─────┐ │ +/// │ │ A ││ C │ │ │ │ A ││ B ││ C ││ D │ │ +/// │ │ - ││ - │ │ │ │ - ││ - ││ - ││ - │ │ +/// ┌──────────────┐ │ │ 1 ││ 10 │ │ ┌──────────────┐ │ │ 1 ││ NULL ││ 10 ││ foo │ │ +/// │ Input │ │ │ 2 ││ 20 │ │ │ Adapter │ │ │ 2 ││ NULL ││ 20 ││ foo │ │ +/// │ Stream ├────▶ │ │ 3 ││ 30 │ │────▶│ Stream ├───▶│ │ 3 ││ NULL ││ 30 ││ foo │ │ +/// └──────────────┘ │ │ 4 ││ 40 │ │ └──────────────┘ │ │ 4 ││ NULL ││ 40 ││ foo │ │ +/// │ └─────┘└─────┘ │ │ └─────┘└──────┘└─────┘└─────┘ │ +/// │ │ │ │ +/// │ Record Batch │ │ Record Batch │ +/// └────────────────┘ └───────────────────────────────┘ +/// ``` +pub(crate) struct SchemaAdapterStream { + input: SendableRecordBatchStream, + /// Output schema of this stream + /// The schema of `input` is always a subset of output_schema + output_schema: SchemaRef, + mappings: Vec, + /// metrics to record execution + baseline_metrics: BaselineMetrics, +} + +impl std::fmt::Debug for SchemaAdapterStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SchemaAdapterStream") + .field("input", &"(OPAQUE STREAM)") + .field("output_schema", &self.output_schema) + .field("mappings", &self.mappings) + .finish() + } +} + +impl SchemaAdapterStream { + /// Try to create a new 
adapter stream that produces batches with + /// the specified output schema. + /// + /// Virtual columns that contain constant values may be added via `virtual_columns`. Note that these columns MUST + /// NOT appear in underlying stream, other wise this method will fail. + /// + /// Columns that appear neither within the underlying stream nor a specified as virtual are created as pure NULL + /// columns. Note that the column must be nullable for this to work. + /// + /// If the underlying stream produces columns that DO NOT appear + /// in the output schema, or are different types than the output + /// schema, an error will be produced. + pub(crate) fn try_new( + input: SendableRecordBatchStream, + output_schema: SchemaRef, + virtual_columns: &HashMap<&str, ScalarValue>, + baseline_metrics: BaselineMetrics, + ) -> Result { + // record this setup time + let timer = baseline_metrics.elapsed_compute().timer(); + + let input_schema = input.schema(); + + // Figure out how to compute each column in the output + let mappings = output_schema + .fields() + .iter() + .map(|output_field| { + let input_field_index = input_schema + .fields() + .iter() + .enumerate() + .find(|(_, input_field)| output_field.name() == input_field.name()) + .map(|(idx, _)| idx); + + if let Some(input_field_index) = input_field_index { + ColumnMapping::FromInput(input_field_index) + } else if let Some(value) = virtual_columns.get(output_field.name().as_str()) { + ColumnMapping::Virtual(value.clone()) + } else { + ColumnMapping::MakeNull(output_field.data_type().clone()) + } + }) + .collect::>(); + + // Verify the mappings match the output type + for (output_index, mapping) in mappings.iter().enumerate() { + let output_field = output_schema.field(output_index); + + match mapping { + ColumnMapping::FromInput(input_index) => { + let input_field = input_schema.field(*input_index); + if input_field.data_type() != output_field.data_type() { + return InternalDataTypeMismatchSnafu { + input_field_name: 
input_field.name(), + input_field_type: input_field.data_type().clone(), + output_field_name: output_field.name(), + output_field_type: output_field.data_type().clone(), + } + .fail(); + } + + if virtual_columns.contains_key(input_field.name().as_str()) { + return InternalColumnBothInInputAndVirtualSnafu { + field_name: input_field.name().clone(), + } + .fail(); + } + } + ColumnMapping::MakeNull(_) => { + if !output_field.is_nullable() { + return InternalColumnNotNullableSnafu { + field_name: output_field.name().clone(), + output_field_type: output_field.data_type().clone(), + } + .fail(); + } + } + ColumnMapping::Virtual(value) => { + let data_type = value.data_type(); + if &data_type != output_field.data_type() { + return InternalDataTypeMismatchForVirtualSnafu { + field_type: data_type, + output_field_name: output_field.name(), + output_field_type: output_field.data_type().clone(), + } + .fail(); + } + } + } + } + + timer.done(); + Ok(Self { + input, + output_schema, + mappings, + baseline_metrics, + }) + } + + /// Extends the record batch, if needed, so that it matches the schema + fn extend_batch(&self, batch: RecordBatch) -> Result { + let output_columns = self + .mappings + .iter() + .map(|mapping| match mapping { + ColumnMapping::FromInput(input_index) => Ok(Arc::clone(batch.column(*input_index))), + ColumnMapping::MakeNull(data_type) => { + Ok(new_null_array(data_type, batch.num_rows())) + } + ColumnMapping::Virtual(value) => value.to_array_of_size(batch.num_rows()), + }) + .collect::, DataFusionError>>()?; + + Ok(RecordBatch::try_new( + Arc::clone(&self.output_schema), + output_columns, + )?) 
+ } +} + +impl RecordBatchStream for SchemaAdapterStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.output_schema) + } +} + +impl Stream for SchemaAdapterStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + ctx: &mut Context<'_>, + ) -> Poll> { + // the Poll result is an Opton> so we need a few + // layers of maps to get at the actual batch, if any + let poll = self.input.as_mut().poll_next(ctx).map(|maybe_result| { + maybe_result.map(|batch| batch.and_then(|batch| self.extend_batch(batch))) + }); + self.baseline_metrics.record_poll(poll) + } + + // TODO is there a useful size_hint to pass? +} + +/// Describes how to create column in the output. +#[derive(Debug)] +enum ColumnMapping { + /// Output column is found at `` column of the input schema + FromInput(usize), + + /// Output colum should be synthesized with nulls of the specified type + MakeNull(DataType), + + /// Create virtual chunk column + Virtual(ScalarValue), +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + use arrow::{ + array::{ArrayRef, Int32Array, StringArray}, + datatypes::{Field, Schema}, + record_batch::RecordBatch, + }; + use arrow_util::assert_batches_eq; + use datafusion::physical_plan::{common::collect, metrics::ExecutionPlanMetricsSet}; + use datafusion_util::stream_from_batch; + use test_helpers::assert_contains; + + #[tokio::test] + async fn same_input_and_output() { + let batch = make_batch(); + + let output_schema = batch.schema(); + let input_stream = stream_from_batch(batch.schema(), batch); + let adapter_stream = SchemaAdapterStream::try_new( + input_stream, + output_schema, + &Default::default(), + baseline_metrics(), + ) + .unwrap(); + + let output = collect(Box::pin(adapter_stream)) + .await + .expect("Running plan"); + let expected = vec![ + "+---+---+-----+", + "| a | b | c |", + "+---+---+-----+", + "| 1 | 4 | foo |", + "| 2 | 5 | bar |", + "| 3 | 6 | baz |", + "+---+---+-----+", + ]; + 
assert_batches_eq!(&expected, &output); + } + + #[tokio::test] + async fn input_different_order_than_output() { + let batch = make_batch(); + // input has columns in different order than desired output + + let output_schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Utf8, false), + Field::new("a", DataType::Int32, false), + ])); + let input_stream = stream_from_batch(batch.schema(), batch); + let adapter_stream = SchemaAdapterStream::try_new( + input_stream, + output_schema, + &Default::default(), + baseline_metrics(), + ) + .unwrap(); + + let output = collect(Box::pin(adapter_stream)) + .await + .expect("Running plan"); + let expected = vec![ + "+---+-----+---+", + "| b | c | a |", + "+---+-----+---+", + "| 4 | foo | 1 |", + "| 5 | bar | 2 |", + "| 6 | baz | 3 |", + "+---+-----+---+", + ]; + assert_batches_eq!(&expected, &output); + } + + #[tokio::test] + async fn input_subset_of_output() { + let batch = make_batch(); + // input has subset of columns of the desired otuput. 
d and e are not present + let output_schema = Arc::new(Schema::new(vec![ + Field::new("c", DataType::Utf8, false), + Field::new("e", DataType::Float64, true), + Field::new("b", DataType::Int32, false), + Field::new("d", DataType::Float32, true), + Field::new("f", DataType::Utf8, true), + Field::new("g", DataType::Int32, false), + Field::new("h", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + ])); + let input_stream = stream_from_batch(batch.schema(), batch); + let adapter_stream = SchemaAdapterStream::try_new( + input_stream, + output_schema, + &HashMap::from([ + ("f", ScalarValue::from("xxx")), + ("g", ScalarValue::from(1i32)), + ("h", ScalarValue::from(1i32)), + ]), + baseline_metrics(), + ) + .unwrap(); + + let output = collect(Box::pin(adapter_stream)) + .await + .expect("Running plan"); + let expected = vec![ + "+-----+---+---+---+-----+---+---+---+", + "| c | e | b | d | f | g | h | a |", + "+-----+---+---+---+-----+---+---+---+", + "| foo | | 4 | | xxx | 1 | 1 | 1 |", + "| bar | | 5 | | xxx | 1 | 1 | 2 |", + "| baz | | 6 | | xxx | 1 | 1 | 3 |", + "+-----+---+---+---+-----+---+---+---+", + ]; + assert_batches_eq!(&expected, &output); + } + + #[tokio::test] + async fn input_superset_of_columns() { + let batch = make_batch(); + + // No such column "b" in output -- column would be lost + let output_schema = Arc::new(Schema::new(vec![ + Field::new("c", DataType::Utf8, false), + Field::new("a", DataType::Int32, false), + ])); + let input_stream = stream_from_batch(batch.schema(), batch); + let adapter_stream = SchemaAdapterStream::try_new( + input_stream, + output_schema, + &Default::default(), + baseline_metrics(), + ) + .unwrap(); + + let output = collect(Box::pin(adapter_stream)) + .await + .expect("Running plan"); + let expected = vec![ + "+-----+---+", + "| c | a |", + "+-----+---+", + "| foo | 1 |", + "| bar | 2 |", + "| baz | 3 |", + "+-----+---+", + ]; + assert_batches_eq!(&expected, &output); + } + + #[tokio::test] + async fn 
input_has_different_type() { + let batch = make_batch(); + + // column c has string type in input, output asks float32 + let output_schema = Arc::new(Schema::new(vec![ + Field::new("c", DataType::Float32, false), + Field::new("b", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + ])); + let input_stream = stream_from_batch(batch.schema(), batch); + let res = SchemaAdapterStream::try_new( + input_stream, + output_schema, + &Default::default(), + baseline_metrics(), + ); + + assert_contains!(res.unwrap_err().to_string(), "input field 'c' had type 'Utf8' which is different than output field 'c' which had type 'Float32'"); + } + + #[tokio::test] + async fn virtual_col_has_wrong_type() { + let batch = make_batch(); + + let output_schema = Arc::new(Schema::new(vec![ + Field::new("c", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), + Field::new("d", DataType::UInt8, false), + Field::new("a", DataType::Int32, false), + ])); + let input_stream = stream_from_batch(batch.schema(), batch); + let res = SchemaAdapterStream::try_new( + input_stream, + output_schema, + &HashMap::from([("d", ScalarValue::from(1u32))]), + baseline_metrics(), + ); + + assert_contains!(res.unwrap_err().to_string(), "creating virtual value of type 'UInt32' which is different than output field 'd' which had type 'UInt8'"); + } + + #[tokio::test] + async fn virtual_col_also_in_input() { + let batch = make_batch(); + + let output_schema = Arc::new(Schema::new(vec![ + Field::new("c", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), + Field::new("d", DataType::Utf8, false), + Field::new("a", DataType::Int32, false), + ])); + let input_stream = stream_from_batch(batch.schema(), batch); + let res = SchemaAdapterStream::try_new( + input_stream, + output_schema, + &HashMap::from([ + ("a", ScalarValue::from(1i32)), + ("d", ScalarValue::from("foo")), + ]), + baseline_metrics(), + ); + + assert_contains!(res.unwrap_err().to_string(), "the field 'a' is 
specified within the input and as a virtual column, don't know which one to choose"); + } + + #[tokio::test] + async fn null_non_nullable_column() { + let batch = make_batch(); + + let output_schema = Arc::new(Schema::new(vec![ + Field::new("c", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + Field::new("d", DataType::Utf8, false), + ])); + let input_stream = stream_from_batch(batch.schema(), batch); + let res = SchemaAdapterStream::try_new( + input_stream, + output_schema, + &Default::default(), + baseline_metrics(), + ); + + assert_contains!(res.unwrap_err().to_string(), "field 'd' had output type 'Utf8' and should be a NULL column but the field is flagged as 'not null'"); + } + + // input has different column types than desired output + + fn make_batch() -> RecordBatch { + let col_a = Arc::new(Int32Array::from(vec![1, 2, 3])); + let col_b = Arc::new(Int32Array::from(vec![4, 5, 6])); + let col_c = Arc::new(StringArray::from(vec!["foo", "bar", "baz"])); + + RecordBatch::try_from_iter(vec![("a", col_a as ArrayRef), ("b", col_b), ("c", col_c)]) + .unwrap() + } + + /// Create a BaselineMetrics object for testing + fn baseline_metrics() -> BaselineMetrics { + BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0) + } +} diff --git a/iox_query/src/provider/deduplicate.rs b/iox_query/src/provider/deduplicate.rs new file mode 100644 index 0000000..45c0250 --- /dev/null +++ b/iox_query/src/provider/deduplicate.rs @@ -0,0 +1,1238 @@ +//! 
Implemention of DeduplicateExec operator (resolves primary key conflicts) plumbing and tests +mod algo; + +use std::{collections::HashSet, fmt, sync::Arc}; + +use arrow::{error::ArrowError, record_batch::RecordBatch}; +use datafusion_util::{watch::WatchedTask, AdapterStream}; + +use crate::CHUNK_ORDER_COLUMN_NAME; + +use self::algo::get_col_name; +pub use self::algo::RecordBatchDeduplicator; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::{ + error::{DataFusionError, Result}, + execution::context::TaskContext, + physical_expr::PhysicalSortRequirement, + physical_plan::{ + expressions::{Column, PhysicalSortExpr}, + metrics::{ + self, BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, RecordOutput, + }, + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + SendableRecordBatchStream, Statistics, + }, +}; +use futures::StreamExt; +use observability_deps::tracing::{debug, trace}; +use tokio::sync::mpsc; + +/// # DeduplicateExec +/// +/// This operator takes an input stream of RecordBatches that is +/// already sorted on "sort_key" and applies IOx specific deduplication +/// logic. +/// +/// The output is dependent on the order of the the input rows which +/// have the same key. +/// +/// Specifically, the value chosen for each non-sort_key column is the +/// "last" non-null value. This is used to model "upserts" when new +/// rows with the same primary key are inserted a second time to update +/// existing values. 
+/// +/// # Example +/// For example, given a sort key of (t1, t2) and the following input +/// (already sorted on t1 and t2): +/// +/// ```text +/// +----+----+----+----+ +/// | t1 | t2 | f1 | f2 | +/// +----+----+----+----+ +/// | a | x | 2 | | +/// | a | x | 2 | 1 | +/// | a | x | | 3 | +/// | a | y | 3 | 1 | +/// | b | y | 3 | | +/// | c | y | 1 | 1 | +/// +----+----+----+----+ +/// ``` +/// +/// This operator will produce the following output (note the values +/// chosen for (a, x)): +/// +/// ```text +/// +----+----+----+----+ +/// | t1 | t2 | f1 | f2 | +/// +----+----+----+----+ +/// | a | x | 2 | 3 | +/// | a | y | 3 | 1 | +/// | b | y | 3 | | +/// | c | y | 1 | 1 | +/// +----+----+----+----+ +/// ``` +/// +/// # Field Resolution (why the last non-null value?) +/// +/// The choice of the latest non-null value instead of the latest value is +/// subtle and thus we try to document the rationale here. It is a +/// consequence of the LineProtocol update model. +/// +/// Some observations about line protocol are: +/// +/// 1. Lines are treated as "UPSERT"s (aka updating any existing +/// values, possibly adding new fields) +/// +/// 2. Fields can not be removed or set to NULL via a line (So if a +/// field has a NULL value it means the user didn't provide a value +/// for that field) +/// +/// For example, this data (with a NULL for `f2`): +/// +/// ```text +/// t1 | f1 | f2 +/// ---+----+---- +/// a | 1 | 3 +// a | 2 | +/// ``` +/// +/// Would have come from line protocol like +/// ```text +/// m,t1=a f1=1,f2=3 +/// m,t1=a f1=3 +/// ``` +/// (note there was no value for f2 provided in the second line, it can +/// be read as "upsert value of f1=3, the value of f2 is not modified). +/// +/// Thus it would not be correct to take the latest value from f2 +/// (NULL) as in the source input the field's value was not provided. 
#[derive(Debug)]
pub struct DeduplicateExec {
    input: Arc<dyn ExecutionPlan>,
    sort_keys: Vec<PhysicalSortExpr>,
    input_order: Vec<PhysicalSortExpr>,
    use_chunk_order_col: bool,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
}

impl DeduplicateExec {
    pub fn new(
        input: Arc<dyn ExecutionPlan>,
        sort_keys: Vec<PhysicalSortExpr>,
        use_chunk_order_col: bool,
    ) -> Self {
        let mut input_order = sort_keys.clone();
        if use_chunk_order_col {
            input_order.push(PhysicalSortExpr {
                expr: Arc::new(
                    Column::new_with_schema(CHUNK_ORDER_COLUMN_NAME, &input.schema())
                        .expect("input has chunk order col"),
                ),
                options: Default::default(),
            })
        }
        Self {
            input,
            sort_keys,
            input_order,
            use_chunk_order_col,
            metrics: ExecutionPlanMetricsSet::new(),
        }
    }

    pub fn sort_keys(&self) -> &[PhysicalSortExpr] {
        &self.sort_keys
    }

    /// Combination of all columns within the sort key and potentially the chunk order column.
    pub fn sort_columns(&self) -> HashSet<&str> {
        self.input_order
            .iter()
            .map(|sk| get_col_name(sk.expr.as_ref()))
            .collect()
    }

    pub fn use_chunk_order_col(&self) -> bool {
        self.use_chunk_order_col
    }
}

#[derive(Debug)]
struct DeduplicateMetrics {
    baseline_metrics: BaselineMetrics,
    num_dupes: metrics::Count,
}

impl DeduplicateMetrics {
    fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
        Self {
            baseline_metrics: BaselineMetrics::new(metrics, partition),
            num_dupes: MetricBuilder::new(metrics).counter("num_dupes", partition),
        }
    }
}

impl ExecutionPlan for DeduplicateExec {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn schema(&self) -> arrow::datatypes::SchemaRef {
        self.input.schema()
    }

    fn output_partitioning(&self) -> Partitioning {
        Partitioning::UnknownPartitioning(1)
    }

    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
        trace!("Deduplicate output ordering: {:?}", self.sort_keys);
        Some(&self.sort_keys)
    }

    fn required_input_ordering(&self) -> Vec<Option<Vec<PhysicalSortRequirement>>> {
vec![Some(PhysicalSortRequirement::from_sort_exprs( + &self.input_order, + ))] + } + + fn maintains_input_order(&self) -> Vec { + vec![true] + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] + } + + fn children(&self) -> Vec> { + vec![Arc::clone(&self.input)] + } + + fn equivalence_properties(&self) -> EquivalenceProperties { + // deduplicate does not change the equivalence properties + self.input.equivalence_properties() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion::error::Result> { + assert_eq!(children.len(), 1); + let input = Arc::clone(&children[0]); + Ok(Arc::new(Self::new( + input, + self.sort_keys.clone(), + self.use_chunk_order_col, + ))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + trace!(partition, "Start DeduplicationExec::execute"); + + if partition != 0 { + return Err(DataFusionError::Internal( + "DeduplicateExec only supports a single input stream".to_string(), + )); + } + let deduplicate_metrics = DeduplicateMetrics::new(&self.metrics, partition); + + let input_stream = self.input.execute(0, context)?; + + // the deduplication is performed in a separate task which is + // then sent via a channel to the output + let (tx, rx) = mpsc::channel(1); + + let fut = deduplicate( + input_stream, + self.sort_keys.clone(), + tx.clone(), + deduplicate_metrics, + ); + + // A second task watches the output of the worker task and reports errors + let handle = WatchedTask::new(fut, vec![tx], "deduplicate batches"); + + debug!( + partition, + "End building stream for DeduplicationExec::execute" + ); + + Ok(AdapterStream::adapt(self.schema(), rx, handle)) + } + + fn required_input_distribution(&self) -> Vec { + // For now use a single input -- it might be helpful + // eventually to deduplicate in parallel by hash partitioning + // the inputs (based on sort keys) + vec![Distribution::SinglePartition] + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + 
} + + fn statistics(&self) -> Result { + // use a guess from our input but they are NOT exact + Ok(self.input.statistics()?.into_inexact()) + } +} + +impl DisplayAs for DeduplicateExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let expr: Vec = self.sort_keys.iter().map(|e| e.to_string()).collect(); + write!(f, "DeduplicateExec: [{}]", expr.join(",")) + } + } + } +} + +async fn deduplicate( + mut input_stream: SendableRecordBatchStream, + sort_keys: Vec, + tx: mpsc::Sender>, + deduplicate_metrics: DeduplicateMetrics, +) -> Result<(), DataFusionError> { + let DeduplicateMetrics { + baseline_metrics, + num_dupes, + } = deduplicate_metrics; + + let elapsed_compute = baseline_metrics.elapsed_compute(); + let mut deduplicator = RecordBatchDeduplicator::new(sort_keys, num_dupes, None); + + // Stream input through the indexer + while let Some(batch) = input_stream.next().await { + let batch = batch?; + + // First check if this batch has same sort key with its previous batch + let timer = elapsed_compute.timer(); + if let Some(last_batch) = deduplicator + .last_batch_with_no_same_sort_key(&batch) + .record_output(&baseline_metrics) + { + timer.done(); + // No, different sort key, so send the last batch downstream first + if last_batch.num_rows() > 0 { + tx.send(Ok(last_batch)) + .await + .map_err(|e| ArrowError::from_external_error(Box::new(e)))?; + } + } else { + timer.done() + } + + // deduplicate data of the batch + let timer = elapsed_compute.timer(); + let output_batch = deduplicator.push(batch)?.record_output(&baseline_metrics); + timer.done(); + if output_batch.num_rows() > 0 { + tx.send(Ok(output_batch)) + .await + .map_err(|e| ArrowError::from_external_error(Box::new(e)))?; + } + } + debug!("before sending the left over batch"); + + // send any left over batch + let timer = elapsed_compute.timer(); + if let Some(output_batch) = 
deduplicator.finish()?.record_output(&baseline_metrics) { + timer.done(); + if output_batch.num_rows() > 0 { + tx.send(Ok(output_batch)) + .await + .map_err(|e| ArrowError::from_external_error(Box::new(e)))?; + } + } else { + timer.done() + } + debug!("done sending the left over batch"); + + Ok(()) +} + +#[cfg(test)] +mod test { + use arrow::compute::SortOptions; + use arrow::datatypes::{Int32Type, SchemaRef}; + use arrow::{ + array::{ArrayRef, Float64Array, StringArray, TimestampNanosecondArray}, + record_batch::RecordBatch, + }; + use arrow_util::assert_batches_eq; + use datafusion::physical_plan::{expressions::col, memory::MemoryExec}; + use datafusion_util::test_collect; + + use super::*; + use arrow::array::{DictionaryArray, Int64Array}; + use schema::TIME_DATA_TIMEZONE; + use std::iter::FromIterator; + + #[tokio::test] + async fn test_single_tag() { + // input: + // t1 | f1 | f2 + // ---+----+---- + // a | 1 | + // a | 2 | 3 + // a | | 4 + // b | 5 | 6 + // c | 7 | + // c | | + // c | | 8 + // + // expected output: + // + // t1 | f1 | f2 + // ---+----+---- + // a | 2 | 4 + // b | 5 | 6 + // c | 7 | 8 + + let t1 = StringArray::from(vec![ + Some("a"), + Some("a"), + Some("a"), + Some("b"), + Some("c"), + Some("c"), + Some("c"), + ]); + let f1 = Float64Array::from(vec![ + Some(1.0), + Some(2.0), + None, + Some(5.0), + Some(7.0), + None, + None, + ]); + let f2 = Float64Array::from(vec![ + None, + Some(3.0), + Some(4.0), + Some(6.0), + None, + None, + Some(8.0), + ]); + + let batch = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("f1", Arc::new(f1) as ArrayRef), + ("f2", Arc::new(f2) as ArrayRef), + ]) + .unwrap(); + + let sort_keys = vec![PhysicalSortExpr { + expr: col("t1", &batch.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + + let results = dedupe(vec![batch], sort_keys).await; + + let expected = vec![ + "+----+-----+-----+", + "| t1 | f1 | f2 |", + "+----+-----+-----+", + "| a | 
2.0 | 4.0 |", + "| b | 5.0 | 6.0 |", + "| c | 7.0 | 8.0 |", + "+----+-----+-----+", + ]; + assert_batches_eq!(&expected, &results.output); + } + + #[tokio::test] + async fn test_with_timestamp() { + // input: + // f1 | f2 | time + // ---+----+------ + // 1 | | 100 + // | 3 | 100 + // + // expected output: + // + // f1 | f2 | time + // ---+----+------- + // 1 | 3 | 100 + let f1 = Float64Array::from(vec![Some(1.0), None]); + let f2 = Float64Array::from(vec![None, Some(3.0)]); + + let time = TimestampNanosecondArray::from(vec![Some(100), Some(100)]) + .with_timezone_opt(TIME_DATA_TIMEZONE()); + + let batch = RecordBatch::try_from_iter(vec![ + ("f1", Arc::new(f1) as ArrayRef), + ("f2", Arc::new(f2) as ArrayRef), + ("time", Arc::new(time) as ArrayRef), + ]) + .unwrap(); + + let sort_keys = vec![PhysicalSortExpr { + expr: col("time", &batch.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + + let results = dedupe(vec![batch], sort_keys).await; + + let expected = vec![ + "+-----+-----+--------------------------------+", + "| f1 | f2 | time |", + "+-----+-----+--------------------------------+", + "| 1.0 | 3.0 | 1970-01-01T00:00:00.000000100Z |", + "+-----+-----+--------------------------------+", + ]; + assert_batches_eq!(&expected, &results.output); + } + + #[tokio::test] + async fn test_multi_tag() { + // input: + // t1 | t2 | f1 | f2 + // ---+----+----+---- + // a | b | 1 | + // a | b | 2 | 3 + // a | b | | 4 + // a | z | 5 | + // b | b | 6 | + // b | c | 7 | 6 + // c | c | 8 | + // d | b | | 9 + // e | | 10 | 11 + // e | | 12 | + // | f | 13 | + // | f | | 14 + // + // expected output: + // t1 | t2 | f1 | f2 + // ---+----+----+---- + // a | b | 2 | 4 + // a | z | 5 | + // b | b | 6 | + // b | c | 7 | 6 + // c | c | 8 | + // d | b | | 9 + // e | | 12 | 11 + // | f | 13 | 14 + + let t1 = StringArray::from(vec![ + Some("a"), + Some("a"), + Some("a"), + Some("a"), + Some("b"), + Some("b"), + Some("c"), + Some("d"), + 
Some("e"), + Some("e"), + None, + None, + ]); + + let t2 = StringArray::from(vec![ + Some("b"), + Some("b"), + Some("b"), + Some("z"), + Some("b"), + Some("c"), + Some("c"), + Some("b"), + None, + None, + Some("f"), + Some("f"), + ]); + + let f1 = Float64Array::from(vec![ + Some(1.0), + Some(2.0), + None, + Some(5.0), + Some(6.0), + Some(7.0), + Some(8.0), + None, + Some(10.0), + Some(12.0), + Some(13.0), + None, + ]); + + let f2 = Float64Array::from(vec![ + None, + Some(3.0), + Some(4.0), + None, + None, + Some(6.0), + None, + Some(9.0), + Some(11.0), + None, + None, + Some(14.0), + ]); + + let batch = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("t2", Arc::new(t2) as ArrayRef), + ("f1", Arc::new(f1) as ArrayRef), + ("f2", Arc::new(f2) as ArrayRef), + ]) + .unwrap(); + + let sort_keys = vec![ + PhysicalSortExpr { + expr: col("t1", &batch.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }, + PhysicalSortExpr { + expr: col("t2", &batch.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }, + ]; + + let results = dedupe(vec![batch], sort_keys).await; + + let expected = vec![ + "+----+----+------+------+", + "| t1 | t2 | f1 | f2 |", + "+----+----+------+------+", + "| a | b | 2.0 | 4.0 |", + "| a | z | 5.0 | |", + "| b | b | 6.0 | |", + "| b | c | 7.0 | 6.0 |", + "| c | c | 8.0 | |", + "| d | b | | 9.0 |", + "| e | | 12.0 | 11.0 |", + "| | f | 13.0 | 14.0 |", + "+----+----+------+------+", + ]; + assert_batches_eq!(&expected, &results.output); + } + + #[tokio::test] + async fn test_string_with_timestamp() { + // input: + // s | i | time + // -------+----+------ + // "cat" | | 100 + // | 3 | 100 + // | 4 | 200 + // "dog" | | 200 + // + // expected output: + // + // s | i | time + // -------+----+------- + // "cat" | 3 | 100 + // "dog" | 4 | 200 + let s = StringArray::from(vec![Some("cat"), None, None, Some("dog")]); + + let i = Int64Array::from(vec![None, 
Some(3), Some(4), None]); + + let time = TimestampNanosecondArray::from(vec![Some(100), Some(100), Some(200), Some(200)]); + + let batch = RecordBatch::try_from_iter(vec![ + ("s", Arc::new(s) as ArrayRef), + ("i", Arc::new(i) as ArrayRef), + ("time", Arc::new(time) as ArrayRef), + ]) + .unwrap(); + + let sort_keys = vec![PhysicalSortExpr { + expr: col("time", &batch.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + + let results = dedupe(vec![batch], sort_keys).await; + + let expected = vec![ + "+-----+---+--------------------------------+", + "| s | i | time |", + "+-----+---+--------------------------------+", + "| cat | 3 | 1970-01-01T00:00:00.000000100Z |", + "| dog | 4 | 1970-01-01T00:00:00.000000200Z |", + "+-----+---+--------------------------------+", + ]; + assert_batches_eq!(&expected, &results.output); + } + + #[tokio::test] + async fn test_last_is_null_with_timestamp() { + // input: + // s | i | time + // -------+----+------ + // "cat" | | 1639612800000000000 + // | 10 | 1639612800000000000 + // + // expected output: + // + // s | i | time + // -------+----+------- + // "cat" | 10 | 1639612800000000000 + let s = StringArray::from(vec![Some("cat"), None]); + + let i = Int64Array::from(vec![None, Some(10)]); + + let time = TimestampNanosecondArray::from(vec![ + Some(1639612800000000000), + Some(1639612800000000000), + ]); + + let batch = RecordBatch::try_from_iter(vec![ + ("s", Arc::new(s) as ArrayRef), + ("i", Arc::new(i) as ArrayRef), + ("time", Arc::new(time) as ArrayRef), + ]) + .unwrap(); + + let sort_keys = vec![PhysicalSortExpr { + expr: col("time", &batch.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + + let results = dedupe(vec![batch], sort_keys).await; + + let expected = vec![ + "+-----+----+----------------------+", + "| s | i | time |", + "+-----+----+----------------------+", + "| cat | 10 | 2021-12-16T00:00:00Z |", + 
"+-----+----+----------------------+", + ]; + assert_batches_eq!(&expected, &results.output); + } + + #[tokio::test] + async fn test_multi_record_batch() { + // input: + // t1 | t2 | f1 | f2 + // ---+----+----+---- + // a | b | 1 | 2 + // a | c | 3 | + // a | c | 4 | 5 + // ====(next batch)==== + // a | c | | 6 + // b | d | 7 | 8 + + // + // expected output: + // t1 | t2 | f1 | f2 + // ---+----+----+---- + // a | b | 1 | 2 + // a | c | 4 | 6 + // b | d | 7 | 8 + + let t1 = StringArray::from(vec![Some("a"), Some("a"), Some("a")]); + + let t2 = StringArray::from(vec![Some("b"), Some("c"), Some("c")]); + + let f1 = Float64Array::from(vec![Some(1.0), Some(3.0), Some(4.0)]); + + let f2 = Float64Array::from(vec![Some(2.0), None, Some(5.0)]); + + let batch1 = RecordBatch::try_from_iter_with_nullable(vec![ + ("t1", Arc::new(t1) as ArrayRef, true), + ("t2", Arc::new(t2) as ArrayRef, true), + ("f1", Arc::new(f1) as ArrayRef, true), + ("f2", Arc::new(f2) as ArrayRef, true), + ]) + .unwrap(); + + let t1 = StringArray::from(vec![Some("a"), Some("b")]); + + let t2 = StringArray::from(vec![Some("c"), Some("d")]); + + let f1 = Float64Array::from(vec![None, Some(7.0)]); + + let f2 = Float64Array::from(vec![Some(6.0), Some(8.0)]); + + let batch2 = RecordBatch::try_from_iter_with_nullable(vec![ + ("t1", Arc::new(t1) as ArrayRef, true), + ("t2", Arc::new(t2) as ArrayRef, true), + ("f1", Arc::new(f1) as ArrayRef, true), + ("f2", Arc::new(f2) as ArrayRef, true), + ]) + .unwrap(); + + let sort_keys = vec![ + PhysicalSortExpr { + expr: col("t1", &batch2.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }, + PhysicalSortExpr { + expr: col("t2", &batch2.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }, + ]; + + let results = dedupe(vec![batch1, batch2], sort_keys).await; + + let expected = vec![ + "+----+----+-----+-----+", + "| t1 | t2 | f1 | f2 |", + "+----+----+-----+-----+", + "| a | b | 1.0 | 
2.0 |", + "| a | c | 4.0 | 6.0 |", + "| b | d | 7.0 | 8.0 |", + "+----+----+-----+-----+", + ]; + assert_batches_eq!(&expected, &results.output); + // 5 rows in initial input, 3 rows in output ==> 2 dupes + assert_eq!(results.num_dupes(), 5 - 3); + } + + #[tokio::test] + async fn test_no_dupes() { + // special case test for data without duplicates (fast path) + // input: + // t1 | f1 + // ---+---- + // a | 1 + // ====(next batch)==== + // b | 2 + // + // expected output: + // + // t1 | f1 + // ---+---- + // a | 1 + // b | 2 + + let t1 = StringArray::from(vec![Some("a")]); + let f1 = Float64Array::from(vec![Some(1.0)]); + + let batch1 = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("f1", Arc::new(f1) as ArrayRef), + ]) + .unwrap(); + + let t1 = StringArray::from(vec![Some("b")]); + let f1 = Float64Array::from(vec![Some(2.0)]); + + let batch2 = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("f1", Arc::new(f1) as ArrayRef), + ]) + .unwrap(); + + let sort_keys = vec![PhysicalSortExpr { + expr: col("t1", &batch2.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + + let results = dedupe(vec![batch1, batch2], sort_keys).await; + + let expected = vec![ + "+----+-----+", + "| t1 | f1 |", + "+----+-----+", + "| a | 1.0 |", + "| b | 2.0 |", + "+----+-----+", + ]; + assert_batches_eq!(&expected, &results.output); + + // also validate there were no dupes detected + assert_eq!(results.num_dupes(), 0); + } + + #[tokio::test] + async fn test_single_pk() { + // test boundary condition + + // input: + // t1 | f1 | f2 + // ---+----+---- + // a | 1 | 2 + // a | 3 | 4 + // + // expected output: + // + // t1 | f1 | f2 + // ---+----+---- + // a | 3 | 4 + + let t1 = StringArray::from(vec![Some("a"), Some("a")]); + let f1 = Float64Array::from(vec![Some(1.0), Some(3.0)]); + let f2 = Float64Array::from(vec![Some(2.0), Some(4.0)]); + + let batch = RecordBatch::try_from_iter(vec![ + ("t1", 
Arc::new(t1) as ArrayRef), + ("f1", Arc::new(f1) as ArrayRef), + ("f2", Arc::new(f2) as ArrayRef), + ]) + .unwrap(); + + let sort_keys = vec![PhysicalSortExpr { + expr: col("t1", &batch.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + + let results = dedupe(vec![batch], sort_keys).await; + + let expected = vec![ + "+----+-----+-----+", + "| t1 | f1 | f2 |", + "+----+-----+-----+", + "| a | 3.0 | 4.0 |", + "+----+-----+-----+", + ]; + assert_batches_eq!(&expected, &results.output); + } + + #[tokio::test] + async fn test_column_reorder() { + // test if they fields come before tags and tags not in right order + + // input: + // f1 | t2 | t1 + // ---+----+---- + // 1 | a | a + // 2 | a | a + // 3 | a | b + // 4 | b | b + // + // expected output: + // + // f1 | t2 | t1 + // ---+----+---- + // 2 | a | a + // 3 | a | b + // 4 | b | b + + let f1 = Float64Array::from(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)]); + let t2 = StringArray::from(vec![Some("a"), Some("a"), Some("a"), Some("b")]); + let t1 = StringArray::from(vec![Some("a"), Some("a"), Some("b"), Some("b")]); + + let batch = RecordBatch::try_from_iter(vec![ + ("f1", Arc::new(f1) as ArrayRef), + ("t2", Arc::new(t2) as ArrayRef), + ("t1", Arc::new(t1) as ArrayRef), + ]) + .unwrap(); + + let sort_keys = vec![ + PhysicalSortExpr { + expr: col("t1", &batch.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }, + PhysicalSortExpr { + expr: col("t2", &batch.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }, + ]; + + let results = dedupe(vec![batch], sort_keys).await; + + let expected = vec![ + "+-----+----+----+", + "| f1 | t2 | t1 |", + "+-----+----+----+", + "| 2.0 | a | a |", + "| 3.0 | a | b |", + "| 4.0 | b | b |", + "+-----+----+----+", + ]; + assert_batches_eq!(&expected, &results.output); + } + + #[tokio::test] + #[should_panic(expected = "This is the error")] + async 
fn test_input_error_propagated() { + // test that an error from the input gets to the output + + // input: + // t1 | f1 + // ---+---- + // a | 1 + // === next batch === + // (error) + + let t1 = StringArray::from(vec![Some("a")]); + let f1 = Float64Array::from(vec![Some(1.0)]); + + let batch = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("f1", Arc::new(f1) as ArrayRef), + ]) + .unwrap(); + + let schema = batch.schema(); + let batches = vec![ + Ok(batch), + Err(ArrowError::ComputeError("This is the error".to_string())), + ]; + + let input = Arc::new(DummyExec { + schema: Arc::clone(&schema), + batches, + }); + + let sort_keys = vec![PhysicalSortExpr { + expr: col("t1", &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + + let exec: Arc = Arc::new(DeduplicateExec::new(input, sort_keys, false)); + test_collect(exec).await; + } + + #[tokio::test] + async fn test_dictionary() { + let t1 = DictionaryArray::::from_iter(vec![Some("a"), Some("a"), Some("b")]); + let t2 = DictionaryArray::::from_iter(vec![Some("b"), Some("c"), Some("c")]); + let f1 = Float64Array::from(vec![Some(1.0), Some(3.0), Some(4.0)]); + let f2 = Float64Array::from(vec![Some(2.0), None, Some(5.0)]); + + let batch1 = RecordBatch::try_from_iter_with_nullable(vec![ + ("t1", Arc::new(t1) as ArrayRef, true), + ("t2", Arc::new(t2) as ArrayRef, true), + ("f1", Arc::new(f1) as ArrayRef, true), + ("f2", Arc::new(f2) as ArrayRef, true), + ]) + .unwrap(); + + let t1 = DictionaryArray::::from_iter(vec![Some("b"), Some("c")]); + let t2 = DictionaryArray::::from_iter(vec![Some("c"), Some("d")]); + let f1 = Float64Array::from(vec![None, Some(7.0)]); + let f2 = Float64Array::from(vec![Some(6.0), Some(8.0)]); + + let batch2 = RecordBatch::try_from_iter_with_nullable(vec![ + ("t1", Arc::new(t1) as ArrayRef, true), + ("t2", Arc::new(t2) as ArrayRef, true), + ("f1", Arc::new(f1) as ArrayRef, true), + ("f2", Arc::new(f2) as ArrayRef, true), + ]) + 
.unwrap(); + + let sort_keys = vec![ + PhysicalSortExpr { + expr: col("t1", &batch1.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }, + PhysicalSortExpr { + expr: col("t2", &batch1.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }, + ]; + + let results = dedupe(vec![batch1, batch2], sort_keys).await; + + let cols: Vec<_> = results + .output + .iter() + .map(|batch| { + batch + .column(batch.schema().column_with_name("t1").unwrap().0) + .as_any() + .downcast_ref::>() + .unwrap() + }) + .collect(); + + // Should produce optimised dictionaries + // The batching is not important + assert_eq!(cols.len(), 3); + assert_eq!(cols[0].keys().len(), 2); + assert_eq!(cols[0].values().len(), 1); // "a" + assert_eq!(cols[1].keys().len(), 1); + assert_eq!(cols[1].values().len(), 1); // "b" + assert_eq!(cols[2].keys().len(), 1); + assert_eq!(cols[2].values().len(), 1); // "c" + + let expected = vec![ + "+----+----+-----+-----+", + "| t1 | t2 | f1 | f2 |", + "+----+----+-----+-----+", + "| a | b | 1.0 | 2.0 |", + "| a | c | 3.0 | |", + "| b | c | 4.0 | 6.0 |", + "| c | d | 7.0 | 8.0 |", + "+----+----+-----+-----+", + ]; + assert_batches_eq!(&expected, &results.output); + // 5 rows in initial input, 4 rows in output ==> 1 dupes + assert_eq!(results.num_dupes(), 5 - 4); + } + + struct TestResults { + output: Vec, + exec: Arc, + } + + impl TestResults { + /// return the number of duplicates this deduplicator detected + fn num_dupes(&self) -> usize { + let metrics = self.exec.metrics().unwrap(); + + let metrics = metrics + .iter() + .filter(|m| m.value().name() == "num_dupes") + .collect::>(); + + assert_eq!( + metrics.len(), + 1, + "expected only one duplicate metric, found {metrics:?}" + ); + metrics[0].value().as_usize() + } + } + + /// Run the input through the deduplicator and return results + async fn dedupe(input: Vec, sort_keys: Vec) -> TestResults { + 
test_helpers::maybe_start_logging(); + + // Setup in memory stream + let schema = input[0].schema(); + let projection = None; + let input = Arc::new(MemoryExec::try_new(&[input], schema, projection).unwrap()); + + // Create and run the deduplicator + let exec = Arc::new(DeduplicateExec::new(input, sort_keys, false)); + let output = test_collect(Arc::clone(&exec) as Arc).await; + + TestResults { output, exec } + } + + /// A PhysicalPlan that sends a specific set of + /// Result for testing. + #[derive(Debug)] + struct DummyExec { + schema: SchemaRef, + batches: Vec>, + } + + impl ExecutionPlan for DummyExec { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn output_partitioning(&self) -> Partitioning { + unimplemented!(); + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + unimplemented!() + } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + unimplemented!() + } + + fn execute( + &self, + partition: usize, + _context: Arc, + ) -> Result { + assert_eq!(partition, 0); + + debug!(partition, "Start DummyExec::execute"); + + // queue them all up + let (tx, rx) = mpsc::unbounded_channel(); + + // queue up all the results + let batches: Vec<_> = self + .batches + .iter() + .map(|r| match r { + Ok(batch) => Ok(batch.clone()), + Err(e) => Err(DataFusionError::External(e.to_string().into())), + }) + .collect(); + let tx_captured = tx.clone(); + let fut = async move { + for r in batches { + tx_captured.send(r).unwrap(); + } + + Ok(()) + }; + let handle = WatchedTask::new(fut, vec![tx], "dummy send"); + + debug!(partition, "End DummyExec::execute"); + Ok(AdapterStream::adapt_unbounded(self.schema(), rx, handle)) + } + + fn statistics(&self) -> Result { + // don't know anything about the statistics + Ok(Statistics::new_unknown(&self.schema())) + } + } + + impl DisplayAs for DummyExec { + fn fmt_as(&self, _t: 
DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "DummyExec") + } + } +} diff --git a/iox_query/src/provider/deduplicate/algo.rs b/iox_query/src/provider/deduplicate/algo.rs new file mode 100644 index 0000000..a4c24e6 --- /dev/null +++ b/iox_query/src/provider/deduplicate/algo.rs @@ -0,0 +1,841 @@ +//! Implementation of Deduplication algorithm + +use std::{cmp::Ordering, ops::Range, sync::Arc}; + +use arrow::{ + array::{ArrayRef, UInt64Array}, + compute::TakeOptions, + datatypes::{DataType, TimeUnit}, + error::Result as ArrowResult, + record_batch::RecordBatch, +}; + +use arrow_util::optimize::optimize_dictionaries; +use datafusion::physical_plan::{ + coalesce_batches::concat_batches, expressions::PhysicalSortExpr, metrics, PhysicalExpr, +}; +use observability_deps::tracing::{debug, trace}; + +// Handles the deduplication across potentially multiple +// [`RecordBatch`]es which are already sorted on a primary key, +// including primary keys which straddle RecordBatch boundaries +#[derive(Debug)] +pub struct RecordBatchDeduplicator { + sort_keys: Vec, + last_batch: Option, + num_dupes: metrics::Count, +} + +#[derive(Debug)] +struct DuplicateRanges { + /// `is_sort_key[col_idx] = true` if the the input column at + /// `col_idx` is present in sort keys + is_sort_key: Vec, + + /// ranges of row indices where the sort key columns have the + /// same values + ranges: Vec>, +} + +impl RecordBatchDeduplicator { + pub fn new( + sort_keys: Vec, + num_dupes: metrics::Count, + last_batch: Option, + ) -> Self { + Self { + sort_keys, + last_batch, + num_dupes, + } + } + + /// Push a new RecordBatch into the indexer. 
Returns a + /// deduplicated RecordBatch and remembers any currently opened + /// groups + pub fn push(&mut self, batch: RecordBatch) -> ArrowResult { + // If we had a previous batch of rows, add it in here + // + // Potential optimization would be to check if the sort key is actually the same + // for the first row in the new batch and skip this concat if that is the case + let batch = if let Some(last_batch) = self.last_batch.take() { + let schema = last_batch.schema(); + let row_count = last_batch.num_rows() + batch.num_rows(); + debug!(row_count, "Before concat_batches"); + let result = concat_batches(&schema, &[last_batch, batch], row_count)?; + debug!(row_count, "After concat_batches"); + result + } else { + batch + }; + + let mut dupe_ranges = self.compute_ranges(&batch)?; + trace!("Finish computing range"); + + // The last partition may span batches so we can't emit it + // until we have seen the next batch (or we are at end of + // stream) + let last_range = dupe_ranges.ranges.pop(); + + let output_record_batch = self.output_from_ranges(&batch, &dupe_ranges)?; + trace!( + num_rows = output_record_batch.num_rows(), + "Rows of ouput_record_batch" + ); + + // Now, save the last bit of the pk + if let Some(last_range) = last_range { + let len = last_range.end - last_range.start; + let last_batch = Self::slice_record_batch(&batch, last_range.start, len)?; + self.last_batch = Some(last_batch); + } + trace!("done pushing record batch into the indexer"); + + Ok(output_record_batch) + } + + /// Return last_batch if it does not overlap with the given batch + /// Note that since last_batch, if exists, will include at least one row and all of its rows will have the same key + pub fn last_batch_with_no_same_sort_key(&mut self, batch: &RecordBatch) -> Option { + // Take the previous batch, if any, out of it storage self.last_batch + if let Some(last_batch) = self.last_batch.take() { + // Build sorted columns for last_batch and current one + let schema = 
last_batch.schema(); + // is_sort_key[col_idx] = true if it is present in sort keys + let mut is_sort_key: Vec = vec![false; last_batch.columns().len()]; + let last_batch_key_columns = self + .sort_keys + .iter() + .map(|skey| { + // figure out the index of the key columns + let name = get_col_name(skey.expr.as_ref()); + let index = schema.index_of(name).unwrap(); + is_sort_key[index] = true; + + // Key column of last_batch of this index + let last_batch_array = last_batch.column(index); + if last_batch_array.len() == 0 { + panic!("Key column, {name}, of last_batch has no data"); + } + arrow::compute::SortColumn { + values: Arc::clone(last_batch_array), + options: Some(skey.options), + } + }) + .collect::>(); + + // Build sorted columns for current batch + // Schema of both batches are the same + let batch_key_columns = self + .sort_keys + .iter() + .map(|skey| { + // figure out the index of the key columns + let name = get_col_name(skey.expr.as_ref()); + let index = schema.index_of(name).unwrap(); + + // Key column of current batch of this index + let array = batch.column(index); + if array.len() == 0 { + panic!("Key column, {name}, of current batch has no data"); + } + arrow::compute::SortColumn { + values: Arc::clone(array), + options: Some(skey.options), + } + }) + .collect::>(); + + // Zip the 2 key sets of columns for comparison + let zipped = last_batch_key_columns.iter().zip(batch_key_columns.iter()); + + // Compare sort keys of the first row of the given batch the the last_batch + // Note that the batches are sorted and all rows of last_batch have the same sort keys so + // only need to compare last row of the last_batch with the first row of the current batch + let mut same = true; + for (l, r) in zipped { + let last_idx = l.values.len() - 1; + if (l.values.is_valid(last_idx), r.values.is_valid(0)) == (true, true) { + // Both have values, do the actual comparison + let c = + arrow::array::build_compare(l.values.as_ref(), r.values.as_ref()).unwrap(); + + 
match c(last_idx, 0) { + Ordering::Equal => {} + _ => { + same = false; + break; + } + } + } else { + // At least one of the value is invalid, consider they are different + same = false; + break; + } + } + + if same { + // The batches overlap and need to be concatinated + // So, store it back in self.last_batch for the concat_batches later + self.last_batch = Some(last_batch); + None + } else { + // The batches do not overlap, deduplicate and then return the last_batch to get sent downstream + + // Ranges of the batch needed for deduplication + // This last batch include only one range with all same key + let ranges = vec![ + Range { + start: 0, + end: last_batch.num_rows() + }; + 1 + ]; + let dupe_ranges = DuplicateRanges { + is_sort_key, + ranges, + }; + let dedup_last_batch = self.output_from_ranges(&last_batch, &dupe_ranges).unwrap(); + + Some(dedup_last_batch) + } + } else { + None + } + } + + /// Consume the indexer, returning any remaining record batches for output + pub fn finish(mut self) -> ArrowResult> { + self.last_batch + .take() + .map(|last_batch| { + let dupe_ranges = self.compute_ranges(&last_batch)?; + self.output_from_ranges(&last_batch, &dupe_ranges) + }) + .transpose() + } + + /// Computes the ranges where the sort key has the same values + fn compute_ranges(&self, batch: &RecordBatch) -> ArrowResult { + let schema = batch.schema(); + // is_sort_key[col_idx] = true if it is present in sort keys + let mut is_sort_key: Vec = vec![false; batch.columns().len()]; + + // Figure out the columns used to optimize the way we compute the ranges. 
+ // Since in IOx's use cases, every ingesting row is almost unique, the optimal way + // to get the ranges is to compare row by row from the highest cardinality column + // to the lowest one + // + // First get key columns which are the sort key columns in lowest to + // highest cardinality plus time column at the end + let mut columns: Vec<_> = self + .sort_keys + .iter() + .map(|skey| { + // figure out what input column this is for + let name = get_col_name(skey.expr.as_ref()); + let index = schema.index_of(name).unwrap(); + + is_sort_key[index] = true; + + Arc::clone(batch.column(index)) + }) + .collect(); + // + // Then converting the columns order from: lowest cardinality, second lowest, ..., highest cardinality, time + // to: highest cardinality, time, second highest cardinality, ...., lowest cardinality + // + // If the last column is time, swap time with its previous column (if any) which is + // the column with the highest cardinality + let len = columns.len(); + if len > 1 { + if let DataType::Timestamp(TimeUnit::Nanosecond, _) = columns[len - 1].data_type() { + columns.swap(len - 2, len - 1); + } + } + // Reverse the list + columns.reverse(); + + // Compute partitions (aka breakpoints between the ranges) + // Each range (or partition) includes a unique sort key value which is + // a unique combination of PK columns. PK columns consist of all tags and the time col. 
+ let partitions = arrow::compute::partition(&columns)?; + let ranges = partitions.ranges(); + + Ok(DuplicateRanges { + is_sort_key, + ranges, + }) + } + + /// Compute the output record batch that includes the specified ranges + fn output_from_ranges( + &self, + batch: &RecordBatch, + dupe_ranges: &DuplicateRanges, + ) -> ArrowResult { + let ranges = &dupe_ranges.ranges; + + // each range is at least 1 large, so any that have more than + // 1 are duplicates + let num_dupes = ranges.iter().map(|r| r.end - r.start - 1).sum(); + + self.num_dupes.add(num_dupes); + + // Special case when no ranges are duplicated (so just emit input as output) + if num_dupes == 0 { + trace!(num_rows = batch.num_rows(), "No dupes"); + Self::slice_record_batch(batch, 0, ranges.len()) + } else { + trace!(num_dupes, num_rows = batch.num_rows(), "dupes"); + + // Use take kernel + let sort_key_indices = self.compute_sort_key_indices(ranges); + + let take_options = Some(TakeOptions { + check_bounds: false, + }); + + // Form each new column by `take`ing the indices as needed + let new_columns = batch + .columns() + .iter() + .enumerate() + .map(|(input_index, input_array)| { + if dupe_ranges.is_sort_key[input_index] { + arrow::compute::take( + input_array.as_ref(), + &sort_key_indices, + take_options.clone(), + ) + } else { + // pick the last non null value + let field_indices = self.compute_field_indices(ranges, input_array); + + arrow::compute::take( + input_array.as_ref(), + &field_indices, + take_options.clone(), + ) + } + }) + .collect::>>()?; + + let batch = RecordBatch::try_new(batch.schema(), new_columns)?; + // At time of writing, `MutableArrayData` concatenates the + // contents of dictionaries as well; Do a post pass to remove the + // redundancy if possible + optimize_dictionaries(&batch) + } + } + + /// Returns an array of indices, one for each input range (which + /// index is arbitrary as all the values are the same for the sort + /// column in each pk group) + /// + /// ranges: 
0-1, 2-4, 5-6 --> Array[0, 2, 5] + fn compute_sort_key_indices(&self, ranges: &[Range]) -> UInt64Array { + ranges.iter().map(|r| Some(r.start as u64)).collect() + } + + /// Returns an array of indices, one for each input range that + /// return the first non-null value of `input_array` in that range + /// (aka it will pick the index of the field value to use for each + /// pk group) + /// + /// ranges: 0-1, 2-4, 5-6 + /// input array: A, NULL, NULL, C, NULL, NULL + /// --> Array[0, 3, 5] + fn compute_field_indices( + &self, + ranges: &[Range], + input_array: &ArrayRef, + ) -> UInt64Array { + ranges + .iter() + .map(|r| { + let value_index = r + .clone() + .filter(|&i| input_array.is_valid(i)) + .last() + .map(|i| i as u64) + // if all field values are none, pick one arbitrarily + .unwrap_or(r.start as u64); + Some(value_index) + }) + .collect() + } + + /// Create a new record batch from offset --> len + fn slice_record_batch( + batch: &RecordBatch, + offset: usize, + len: usize, + ) -> ArrowResult { + let batch = batch.slice(offset, len); + + // At time of writing, `concat_batches` concatenates the + // contents of dictionaries as well; Do a post pass to remove the + // redundancy if possible + optimize_dictionaries(&batch) + } +} + +/// Get column name out of the `expr`. TODO use +/// schema::SortKey instead. 
+pub(crate) fn get_col_name(expr: &dyn PhysicalExpr) -> &str { + expr.as_any() + .downcast_ref::() + .expect("expected column reference") + .name() +} + +#[cfg(test)] +mod test { + use arrow::array::{Int64Array, TimestampNanosecondArray}; + use arrow::compute::SortOptions; + use arrow::{ + array::{ArrayRef, Float64Array, StringArray}, + record_batch::RecordBatch, + }; + + use arrow_util::assert_batches_eq; + use datafusion::physical_plan::expressions::col; + use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; + + use super::*; + + #[tokio::test] + async fn test_non_overlapped_sorted_batches_one_key_column() { + // Sorted key: t1 + + // Last batch + // t1 | t2 | f1 | f2 + // ---+----+----+---- + // a | b | 1 | 2 + // a | c | 3 | + // a | c | 4 | + + // Current batch + // ====(next batch)==== + // b | c | | 6 + // b | d | 7 | 8 + + // Non overlapped => return last batch + // Expected output = Deduplication of Last batch + // t1 | t2 | f1 | f2 + // ---+----+----+---- + // a | c | 4 | 2 + + // Columns of last_batch + let t1 = StringArray::from(vec![Some("a"), Some("a"), Some("a")]); + let t2 = StringArray::from(vec![Some("b"), Some("c"), Some("c")]); + let f1 = Float64Array::from(vec![Some(1.0), Some(3.0), Some(4.0)]); + let f2 = Float64Array::from(vec![Some(2.0), None, None]); + + let last_batch = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("t2", Arc::new(t2) as ArrayRef), + ("f1", Arc::new(f1) as ArrayRef), + ("f2", Arc::new(f2) as ArrayRef), + ]) + .unwrap(); + + // Columns of current_batch + let t1 = StringArray::from(vec![Some("b"), Some("b")]); + let t2 = StringArray::from(vec![Some("c"), Some("d")]); + let f1 = Float64Array::from(vec![None, Some(7.0)]); + let f2 = Float64Array::from(vec![Some(6.0), Some(8.0)]); + + let current_batch = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("t2", Arc::new(t2) as ArrayRef), + ("f1", Arc::new(f1) as ArrayRef), + ("f2", Arc::new(f2) as 
ArrayRef), + ]) + .unwrap(); + + let sort_keys = vec![PhysicalSortExpr { + expr: col("t1", ¤t_batch.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + + let mut dedupe = RecordBatchDeduplicator::new(sort_keys, make_counter(), Some(last_batch)); + + let results = dedupe + .last_batch_with_no_same_sort_key(¤t_batch) + .unwrap(); + + let expected = vec![ + "+----+----+-----+-----+", + "| t1 | t2 | f1 | f2 |", + "+----+----+-----+-----+", + "| a | c | 4.0 | 2.0 |", + "+----+----+-----+-----+", + ]; + assert_batches_eq!(&expected, &[results]); + } + + #[tokio::test] + async fn test_non_overlapped_sorted_batches_two_key_columns() { + // Sorted key: t1, t2 + + // Last batch + // t1 | t2 | f1 | f2 + // ---+----+----+---- + // a | c | 1 | 2 + // a | c | 3 | + // a | c | 4 | 5 + + // Current batch + // ====(next batch)==== + // b | c | | 6 + // b | d | 7 | 8 + + // Non overlapped => return last batch + // Expected output = Deduplication of last batch + // t1 | t2 | f1 | f2 + // ---+----+----+---- + // a | c | 4 | 5 + + // Columns of last_batch + let t1 = StringArray::from(vec![Some("a"), Some("a"), Some("a")]); + let t2 = StringArray::from(vec![Some("c"), Some("c"), Some("c")]); + let f1 = Float64Array::from(vec![Some(1.0), Some(3.0), Some(4.0)]); + let f2 = Float64Array::from(vec![Some(2.0), None, Some(5.0)]); + + let last_batch = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("t2", Arc::new(t2) as ArrayRef), + ("f1", Arc::new(f1) as ArrayRef), + ("f2", Arc::new(f2) as ArrayRef), + ]) + .unwrap(); + + // Columns of current_batch + let t1 = StringArray::from(vec![Some("b"), Some("b")]); + let t2 = StringArray::from(vec![Some("c"), Some("d")]); + let f1 = Float64Array::from(vec![None, Some(7.0)]); + let f2 = Float64Array::from(vec![Some(6.0), Some(8.0)]); + + let current_batch = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("t2", Arc::new(t2) as ArrayRef), + ("f1", 
Arc::new(f1) as ArrayRef), + ("f2", Arc::new(f2) as ArrayRef), + ]) + .unwrap(); + + let sort_keys = vec![ + PhysicalSortExpr { + expr: col("t1", ¤t_batch.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }, + PhysicalSortExpr { + expr: col("t2", ¤t_batch.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }, + ]; + + let mut dedupe = RecordBatchDeduplicator::new(sort_keys, make_counter(), Some(last_batch)); + + let results = dedupe + .last_batch_with_no_same_sort_key(¤t_batch) + .unwrap(); + + let expected = vec![ + "+----+----+-----+-----+", + "| t1 | t2 | f1 | f2 |", + "+----+----+-----+-----+", + "| a | c | 4.0 | 5.0 |", + "+----+----+-----+-----+", + ]; + assert_batches_eq!(&expected, &[results]); + } + + #[tokio::test] + async fn test_overlapped_sorted_batches_one_key_column() { + // Sorted key: t1 + + // Last batch + // t1 | t2 | f1 | f2 + // ---+----+----+---- + // a | b | 1 | 2 + // a | b | 3 | + + // Current batch + // ====(next batch)==== + // a | b | | 6 + // b | d | 7 | 8 + + // Overlapped => return None + + // Columns of last_batch + let t1 = StringArray::from(vec![Some("a"), Some("a")]); + let t2 = StringArray::from(vec![Some("b"), Some("b")]); + let f1 = Float64Array::from(vec![Some(1.0), Some(3.0)]); + let f2 = Float64Array::from(vec![Some(2.0), None]); + + let last_batch = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("t2", Arc::new(t2) as ArrayRef), + ("f1", Arc::new(f1) as ArrayRef), + ("f2", Arc::new(f2) as ArrayRef), + ]) + .unwrap(); + + // Columns of current_batch + let t1 = StringArray::from(vec![Some("a"), Some("b")]); + let t2 = StringArray::from(vec![Some("b"), Some("d")]); + let f1 = Float64Array::from(vec![None, Some(7.0)]); + let f2 = Float64Array::from(vec![Some(6.0), Some(8.0)]); + + let current_batch = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("t2", Arc::new(t2) as ArrayRef), + ("f1", 
Arc::new(f1) as ArrayRef), + ("f2", Arc::new(f2) as ArrayRef), + ]) + .unwrap(); + + let sort_keys = vec![PhysicalSortExpr { + expr: col("t1", ¤t_batch.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + + let mut dedupe = RecordBatchDeduplicator::new(sort_keys, make_counter(), Some(last_batch)); + + let results = dedupe.last_batch_with_no_same_sort_key(¤t_batch); + assert!(results.is_none()); + } + + #[tokio::test] + async fn test_overlapped_sorted_batches_two_key_columns() { + // Sorted key: t1, t2 + + // Last batch + // t1 | t2 | f1 | f2 + // ---+----+----+---- + // a | b | 1 | 2 + // a | b | 3 | + + // Current batch + // ====(next batch)==== + // a | b | | 6 + // b | d | 7 | 8 + + // Overlapped => return None + + // Columns of last_batch + let t1 = StringArray::from(vec![Some("a"), Some("a")]); + let t2 = StringArray::from(vec![Some("b"), Some("b")]); + let f1 = Float64Array::from(vec![Some(1.0), Some(3.0)]); + let f2 = Float64Array::from(vec![Some(2.0), None]); + + let last_batch = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("t2", Arc::new(t2) as ArrayRef), + ("f1", Arc::new(f1) as ArrayRef), + ("f2", Arc::new(f2) as ArrayRef), + ]) + .unwrap(); + + // Columns of current_batch + let t1 = StringArray::from(vec![Some("a"), Some("b")]); + let t2 = StringArray::from(vec![Some("b"), Some("d")]); + let f1 = Float64Array::from(vec![None, Some(7.0)]); + let f2 = Float64Array::from(vec![Some(6.0), Some(8.0)]); + + let current_batch = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("t2", Arc::new(t2) as ArrayRef), + ("f1", Arc::new(f1) as ArrayRef), + ("f2", Arc::new(f2) as ArrayRef), + ]) + .unwrap(); + + let sort_keys = vec![ + PhysicalSortExpr { + expr: col("t1", ¤t_batch.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }, + PhysicalSortExpr { + expr: col("t2", ¤t_batch.schema()).unwrap(), + options: SortOptions { + 
descending: false, + nulls_first: false, + }, + }, + ]; + + let mut dedupe = RecordBatchDeduplicator::new(sort_keys, make_counter(), Some(last_batch)); + + let results = dedupe.last_batch_with_no_same_sort_key(¤t_batch); + assert!(results.is_none()); + } + + #[tokio::test] + async fn test_non_overlapped_none_last_batch() { + // Sorted key: t1, t2 + + // Current batch + // ====(next batch)==== + // a | b | | 6 + // b | d | 7 | 8 + + // Columns of current_batch + let t1 = StringArray::from(vec![Some("a"), Some("b")]); + let t2 = StringArray::from(vec![Some("b"), Some("d")]); + let f1 = Float64Array::from(vec![None, Some(7.0)]); + let f2 = Float64Array::from(vec![Some(6.0), Some(8.0)]); + + let current_batch = RecordBatch::try_from_iter(vec![ + ("t1", Arc::new(t1) as ArrayRef), + ("t2", Arc::new(t2) as ArrayRef), + ("f1", Arc::new(f1) as ArrayRef), + ("f2", Arc::new(f2) as ArrayRef), + ]) + .unwrap(); + + let sort_keys = vec![ + PhysicalSortExpr { + expr: col("t1", ¤t_batch.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }, + PhysicalSortExpr { + expr: col("t2", ¤t_batch.schema()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }, + ]; + + let mut dedupe = RecordBatchDeduplicator::new(sort_keys, make_counter(), None); + + let results = dedupe.last_batch_with_no_same_sort_key(¤t_batch); + assert!(results.is_none()); + } + + #[tokio::test] + async fn test_compute_ranges() { + // Input columns: + // The input columns are sorted on this sort order: + // (Lowest_Cardinality, Second_Highest_Cardinality, Highest_Cardinality, Time) + // + // Invisible Index | Lowest_Cardinality | Second_Highest_Cardinality | Highest_Cardinality | Time + // (not a real col) + // --------------- | -------------------- | --------------------------- | ------------------- | ---- + // 0 | 1 | 1 | 1 | 1 + // 1 | 1 | 1 | 1 | 10 + // 2 | 1 | 1 | 3 | 8 + // 3 | 1 | 1 | 4 | 9 + // 4 | 1 | 1 | 4 | 9 + // 5 | 1 | 1 | 5 | 1 
+ // 6 | 1 | 1 | 5 | 15 + // 7 | 1 | 2 | 5 | 15 + // 8 | 1 | 2 | 5 | 15 + // 9 | 2 | 2 | 5 | 15 + // Out put ranges: 8 ranges on their invisible indices + // [0, 1], + // [1, 2], + // [2, 3], + // [3, 5], -- 2 rows with same values (1, 1, 4, 9) + // [5, 6], + // [6, 7], + // [7, 9], -- 2 rows with same values (1, 2, 5, 15) + // [9, 10], + + let mut lowest_cardinality = vec![Some("1"); 9]; // 9 first values are all Some(1) + lowest_cardinality.push(Some("2")); // Add Some(2) + let lowest_cardinality = Arc::new(StringArray::from(lowest_cardinality)) as ArrayRef; + + let mut second_highest_cardinality = vec![Some(1.0); 7]; + second_highest_cardinality.append(&mut vec![Some(2.0); 3]); + let second_highest_cardinality = + Arc::new(Float64Array::from(second_highest_cardinality)) as ArrayRef; + + let mut highest_cardinality = vec![Some(1), Some(1), Some(3), Some(4), Some(4)]; + highest_cardinality.append(&mut vec![Some(5); 5]); + let highest_cardinality = Arc::new(Int64Array::from(highest_cardinality)) as ArrayRef; + + let mut time = vec![Some(1), Some(10), Some(8), Some(9), Some(9), Some(1)]; + time.append(&mut vec![Some(15); 4]); + let time = Arc::new(TimestampNanosecondArray::from(time)) as ArrayRef; + + let batch = RecordBatch::try_from_iter(vec![ + ("lowest_cardinality", lowest_cardinality), + ("second_highest_cardinality", second_highest_cardinality), + ("highest_cardinality", highest_cardinality), + ("time", time), + ]) + .unwrap(); + + let options = SortOptions { + descending: false, + nulls_first: false, + }; + + let sort_keys = vec![ + PhysicalSortExpr { + expr: col("lowest_cardinality", &batch.schema()).unwrap(), + options, + }, + PhysicalSortExpr { + expr: col("second_highest_cardinality", &batch.schema()).unwrap(), + options, + }, + PhysicalSortExpr { + expr: col("highest_cardinality", &batch.schema()).unwrap(), + options, + }, + PhysicalSortExpr { + expr: col("time", &batch.schema()).unwrap(), + options, + }, + ]; + + let dedupe = 
RecordBatchDeduplicator::new(sort_keys, make_counter(), None); + let key_ranges = dedupe.compute_ranges(&batch).unwrap().ranges; + + let expected_key_range = vec![ + range(0, 1), + range(1, 2), + range(2, 3), + range(3, 5), + range(5, 6), + range(6, 7), + range(7, 9), + range(9, 10), + ]; + + assert_eq!(key_ranges, expected_key_range); + } + + fn make_counter() -> metrics::Count { + let metrics = ExecutionPlanMetricsSet::new(); + MetricBuilder::new(&metrics).counter("num_dupes", 0) + } + + fn range(start: usize, end: usize) -> Range { + Range { start, end } + } +} diff --git a/iox_query/src/provider/deduplicate/key_ranges.rs b/iox_query/src/provider/deduplicate/key_ranges.rs new file mode 100644 index 0000000..429e06c --- /dev/null +++ b/iox_query/src/provider/deduplicate/key_ranges.rs @@ -0,0 +1,281 @@ +//! Implement iterator and comparator to split data into distinct ranges + +use arrow::array::{build_compare, DynComparator}; +use arrow::buffer::NullBuffer; +use arrow::compute::{SortColumn, SortOptions}; +use arrow::error::{ArrowError, Result as ArrowResult}; + +// use snafu::Snafu; +use std::cmp::Ordering; +use std::iter::Iterator; +use std::ops::Range; + +/// Given a list of key columns, find partition ranges that would partition +/// equal values across columns +/// +/// The returned vec would be of size k where k is cardinality of the values; Consecutive +/// values will be connected: (a, b) and (b, c), where start = 0 and end = n for the first and last +/// range. +/// +/// The algorithm works with any set of data (no sort needed) and columns but it is implemented to optimize the use case in which: +/// 1. Every row is almost unique +/// 2. 
Order of input columns is from highest to lowest cardinality +/// +/// Example Input columns: +/// Invisible Index | Highest_Cardinality | Time | Second_Highest_Cardinality | Lowest_Cardinality +/// --------------- | -------------------- | ---- | -------------------------- | -------------------- +/// 0 | 1 | 1 | 1 | 1 +/// 1 | 1 | 10 | 1 | 1 +/// 2 | 3 | 8 | 1 | 1 +/// 3 | 4 | 9 | 1 | 1 +/// 4 | 4 | 9 | 1 | 1 +/// 5 | 5 | 1 | 1 | 1 +/// 6 | 5 | 15 | 1 | 1 +/// 7 | 5 | 15 | 2 | 1 +/// 8 | 5 | 15 | 2 | 1 +/// 9 | 5 | 15 | 2 | 2 +/// The columns are sorted (and RLE) on this different sort order: +/// (Lowest_Cardinality, Second_Highest_Cardinality, Highest_Cardinality, Time) +/// Out put ranges: 8 ranges on their invisible indices +/// [0, 1], +/// [1, 2], +/// [2, 3], +/// [3, 5], -- 2 rows with same values (4, 9, 1, 1) +/// [5, 6], +/// [6, 7], +/// [7, 9], -- 2 rows with same values (5, 15, 2, 1) +/// [9, 10] + +pub fn key_ranges(columns: &[SortColumn]) -> ArrowResult> + '_> { + KeyRangeIterator::try_new(columns) +} + +struct KeyRangeIterator<'a> { + // function to compare values of columns + comparator: KeyRangeComparator<'a>, + // Number of rows of the columns + num_rows: usize, + // end index of previous range which will be used as starting index of the next computing range + start_range_idx: usize, +} + +impl<'a> KeyRangeIterator<'a> { + fn try_new(columns: &'a [SortColumn]) -> ArrowResult { + if columns.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "Key range requires at least one column".to_string(), + )); + } + let num_rows = columns[0].values.len(); + if columns.iter().any(|item| item.values.len() != num_rows) { + return Err(ArrowError::ComputeError( + "Sort columns have different row counts".to_string(), + )); + }; + + //let comparator = KeyRangeComparator::try_new(columns)?; + Ok(Self { + comparator: KeyRangeComparator::try_new(columns)?, + num_rows, + start_range_idx: 0, + }) + } +} + +impl<'a> Iterator for KeyRangeIterator<'a> { + type 
Item = Range; + + fn next(&mut self) -> Option { + // End of the row + if self.start_range_idx >= self.num_rows { + return None; + } + + let mut idx = self.start_range_idx + 1; + while idx < self.num_rows { + if self.comparator.compare(self.start_range_idx, idx) == Ordering::Equal { + idx += 1; + } else { + break; + } + } + let start = self.start_range_idx; + self.start_range_idx = idx; + Some(Range { start, end: idx }) + } +} + +type KeyRangeCompareItem<'a> = ( + Option<&'a NullBuffer>, // validity of array + DynComparator, // comparator + SortOptions, // sort_option +); + +// Todo: this is the same as LexicographicalComparator. +// Either use it or make it like https://github.com/apache/arrow-rs/issues/563 +/// A comparator that wraps given array data (columns) and can compare data +/// at given two indices. The lifetime is the same at the data wrapped. +pub(super) struct KeyRangeComparator<'a> { + compare_items: Vec>, +} + +fn is_valid(nulls: &Option<&NullBuffer>, idx: usize) -> bool { + nulls + .map(|nulls| nulls.is_valid(idx)) + // if there is no null buffer, the entry is valid + .unwrap_or(true) +} + +impl KeyRangeComparator<'_> { + /// compare values at the wrapped columns with given indices. 
+ pub(super) fn compare(&self, a_idx: usize, b_idx: usize) -> Ordering { + for (nulls, comparator, sort_option) in &self.compare_items { + match (is_valid(nulls, a_idx), is_valid(nulls, b_idx)) { + (true, true) => { + match (comparator)(a_idx, b_idx) { + // equal, move on to next column + Ordering::Equal => continue, + order => { + if sort_option.descending { + return order.reverse(); + } else { + return order; + } + } + } + } + (false, true) => { + return if sort_option.nulls_first { + Ordering::Less + } else { + Ordering::Greater + }; + } + (true, false) => { + return if sort_option.nulls_first { + Ordering::Greater + } else { + Ordering::Less + }; + } + // equal, move on to next column + (false, false) => continue, + } + } + + Ordering::Equal + } + + /// Create a new comparator that will wrap the given columns and give comparison + /// results with two indices. + pub(super) fn try_new(columns: &[SortColumn]) -> ArrowResult> { + let compare_items = columns + .iter() + .map(|column| { + // flatten and convert build comparators + // use Nulls for is_valid checks later to avoid dynamic call + let values = column.values.as_ref(); + + let nulls = values.nulls(); + Ok(( + nulls, + build_compare(values, values)?, + column.options.unwrap_or_default(), + )) + }) + .collect::>>()?; + Ok(KeyRangeComparator { compare_items }) + } +} + +#[cfg(test)] +pub fn range(start: usize, end: usize) -> Range { + Range { start, end } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::ArrayRef; + use arrow::array::{Int64Array, TimestampNanosecondArray}; + + use super::*; + + #[tokio::test] + async fn test_key_ranges() { + // Input columns: + // Invisible Index | Highest_Cardinality | Time | Second_Highest_Cardinality | Lowest_Cardinality + // (not a real col) + // --------------- | -------------------- | ---- | -------------------------- | -------------------- + // 0 | 1 | 1 | 1 | 1 + // 1 | 1 | 10 | 1 | 1 + // 2 | 3 | 8 | 1 | 1 + // 3 | 4 | 9 | 1 | 1 + // 4 | 4 | 
9 | 1 | 1 + // 5 | 5 | 1 | 1 | 1 + // 6 | 5 | 15 | 1 | 1 + // 7 | 5 | 15 | 2 | 1 + // 8 | 5 | 15 | 2 | 1 + // 9 | 5 | 15 | 2 | 2 + // The columns are sorted on this sort order: + // (Lowest_Cardinality, Second_Highest_Cardinality, Highest_Cardinality, Time) + // But when the key_ranges function is invoked, the input columns will be + // (Highest_Cardinality, Time, Second_Highest_Cardinality, Lowest_Cardinality) + // Out put ranges: 8 ranges on their invisible indices + // [0, 1], + // [1, 2], + // [2, 3], + // [3, 5], -- 2 rows with same values (4, 9, 1, 1) + // [5, 6], + // [6, 7], + // [7, 9], -- 2 rows with same values (5, 15, 2, 1) + // [9, 10], + + let mut lowest_cardinality = vec![Some(1); 9]; // 9 first values are all Some(1) + lowest_cardinality.push(Some(2)); // Add Some(2) + + let mut second_highest_cardinality = vec![Some(1); 7]; + second_highest_cardinality.append(&mut vec![Some(2); 3]); + + let mut time = vec![Some(1), Some(10), Some(8), Some(9), Some(9), Some(1)]; + time.append(&mut vec![Some(15); 4]); + + let mut highest_cardinality = vec![Some(1), Some(1), Some(3), Some(4), Some(4)]; + highest_cardinality.append(&mut vec![Some(5); 5]); + + let input = vec![ + SortColumn { + values: Arc::new(Int64Array::from(highest_cardinality)) as ArrayRef, + options: None, + }, + SortColumn { + values: Arc::new(TimestampNanosecondArray::from(time)) as ArrayRef, + options: None, + }, + SortColumn { + values: Arc::new(Int64Array::from(second_highest_cardinality)) as ArrayRef, + options: None, + }, + SortColumn { + values: Arc::new(Int64Array::from(lowest_cardinality)) as ArrayRef, + options: None, + }, + ]; + + let key_ranges = key_ranges(&input).unwrap(); + + let expected_key_range = vec![ + range(0, 1), + range(1, 2), + range(2, 3), + range(3, 5), + range(5, 6), + range(6, 7), + range(7, 9), + range(9, 10), + ]; + + assert_eq!(key_ranges.collect::>(), expected_key_range); + } +} diff --git a/iox_query/src/provider/overlap.rs b/iox_query/src/provider/overlap.rs new 
file mode 100644 index 0000000..4b90162 --- /dev/null +++ b/iox_query/src/provider/overlap.rs @@ -0,0 +1,392 @@ +//! Contains the algorithm to determine which chunks may contain "duplicate" primary keys (that is +//! where data with the same combination of "tag" columns and timestamp in the InfluxDB DataModel +//! have been written in via multiple distinct line protocol writes (and thus are stored in +//! separate rows) + +use crate::QueryChunk; +use data_types::TimestampMinMax; +use datafusion::scalar::ScalarValue; +use observability_deps::tracing::debug; +use schema::TIME_COLUMN_NAME; +use std::sync::Arc; + +/// Groups query chunks into disjoint sets of overlapped time range. +/// Does not preserve or guarantee any ordering. +pub fn group_potential_duplicates( + chunks: Vec>, +) -> Vec>> { + let ts: Vec<_> = chunks + .iter() + .map(|c| timestamp_min_max(c.as_ref())) + .collect(); + + // If at least one of the chunks has no time range, + // all chunks are considered to overlap with each other. 
+ if ts.iter().any(|ts| ts.is_none()) { + debug!("At least one chunk has not timestamp min max"); + return vec![chunks]; + } + + // Use this algorithm to group them + // https://towardsdatascience.com/overlapping-time-period-problem-b7f1719347db + + let num_chunks = chunks.len(); + let mut grouper = Vec::with_capacity(num_chunks * 2); + + #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] + enum StartEnd { + Start, + End, + } + #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] + struct StartEndChunk { + start_end: StartEnd, + chunk: Option, + } + struct GrouperRecord { + value: V, + start_end_chunk: StartEndChunk, + } + + for (chunk, ts) in chunks.into_iter().zip(ts) { + let time_range = ts.expect("Time range should have value"); + + grouper.push(GrouperRecord { + value: time_range.min, + start_end_chunk: StartEndChunk { + start_end: StartEnd::Start, + chunk: None, + }, + }); + grouper.push(GrouperRecord { + value: time_range.max, + start_end_chunk: StartEndChunk { + start_end: StartEnd::End, + chunk: Some(chunk), + }, + }); + } + + grouper.sort_by_key(|gr| (gr.value, gr.start_end_chunk.start_end)); + + let mut cumulative_sum = 0; + let mut groups = Vec::with_capacity(num_chunks); + + for gr in grouper { + cumulative_sum += match gr.start_end_chunk.start_end { + StartEnd::Start => 1, + StartEnd::End => -1, + }; + + if matches!(gr.start_end_chunk.start_end, StartEnd::Start) && cumulative_sum == 1 { + groups.push(Vec::with_capacity(num_chunks)); + } + if let StartEnd::End = gr.start_end_chunk.start_end { + groups + .last_mut() + .expect("a start should have pushed at least one empty group") + .push(gr.start_end_chunk.chunk.expect("Must have chunk value")); + } + } + groups +} + +fn timestamp_min_max(chunk: &dyn QueryChunk) -> Option { + let stats = chunk.stats(); + chunk + .schema() + .find_index_of(TIME_COLUMN_NAME) + .map(|idx| &stats.column_statistics[idx]) + .and_then(|stats| { + if let ( + Some(ScalarValue::TimestampNanosecond(Some(min), _)), + 
Some(ScalarValue::TimestampNanosecond(Some(max), _)), + ) = (stats.min_value.get_value(), stats.max_value.get_value()) + { + Some(TimestampMinMax::new(*min, *max)) + } else { + None + } + }) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{test::TestChunk, QueryChunk}; + + #[macro_export] + macro_rules! assert_groups_eq { + ($EXPECTED_LINES: expr, $GROUPS: expr) => { + let expected_lines: Vec = + $EXPECTED_LINES.into_iter().map(|s| s.to_string()).collect(); + + let actual_lines = to_string($GROUPS); + + assert_eq!( + expected_lines, actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + }; + } + + // Test cases: + + #[test] + fn one_time_column_overlap_same_min_max() { + let c1 = Arc::new( + TestChunk::new("chunk1") + .with_time_column() + .with_timestamp_min_max(1, 1), + ); + let c2 = Arc::new( + TestChunk::new("chunk2") + .with_time_column() + .with_timestamp_min_max(1, 1), + ); + + let groups = group_potential_duplicates(vec![c1, c2]); + + let expected = vec!["Group 0: [chunk1, chunk2]"]; + assert_groups_eq!(expected, groups); + } + + #[test] + fn one_time_column_overlap_bad_case() { + let c1 = Arc::new( + TestChunk::new("chunk1") + .with_time_column() + .with_timestamp_min_max(1, 10), + ); + let c2 = Arc::new( + TestChunk::new("chunk2") + .with_time_column() + .with_timestamp_min_max(15, 30), + ); + let c3 = Arc::new( + TestChunk::new("chunk3") + .with_time_column() + .with_timestamp_min_max(7, 20), + ); + let c4 = Arc::new( + TestChunk::new("chunk4") + .with_time_column() + .with_timestamp_min_max(25, 35), + ); + + let groups = group_potential_duplicates(vec![c1, c2, c3, c4]); + + let expected = vec!["Group 0: [chunk1, chunk3, chunk2, chunk4]"]; + assert_groups_eq!(expected, groups); + } + + #[test] + fn one_time_column_overlap_contiguous() { + let c1 = Arc::new( + TestChunk::new("chunk1") + .with_time_column() + .with_timestamp_min_max(1, 10), + ); + let c2 = Arc::new( + TestChunk::new("chunk2") + 
.with_time_column() + .with_timestamp_min_max(7, 20), + ); + let c3 = Arc::new( + TestChunk::new("chunk3") + .with_time_column() + .with_timestamp_min_max(15, 30), + ); + let c4 = Arc::new( + TestChunk::new("chunk4") + .with_time_column() + .with_timestamp_min_max(25, 35), + ); + + let groups = group_potential_duplicates(vec![c1, c2, c3, c4]); + + let expected = vec!["Group 0: [chunk1, chunk2, chunk3, chunk4]"]; + assert_groups_eq!(expected, groups); + } + + #[test] + fn one_time_column_overlap_2_groups() { + let c1 = Arc::new( + TestChunk::new("chunk1") + .with_time_column() + .with_timestamp_min_max(1, 10), + ); + let c2 = Arc::new( + TestChunk::new("chunk2") + .with_time_column() + .with_timestamp_min_max(7, 20), + ); + let c3 = Arc::new( + TestChunk::new("chunk3") + .with_time_column() + .with_timestamp_min_max(21, 30), + ); + let c4 = Arc::new( + TestChunk::new("chunk4") + .with_time_column() + .with_timestamp_min_max(25, 35), + ); + + let groups = group_potential_duplicates(vec![c1, c2, c3, c4]); + + let expected = vec!["Group 0: [chunk1, chunk2]", "Group 1: [chunk3, chunk4]"]; + assert_groups_eq!(expected, groups); + } + + #[test] + fn one_time_column_overlap_3_groups() { + let c1 = Arc::new( + TestChunk::new("chunk1") + .with_time_column() + .with_timestamp_min_max(1, 10), + ); + let c2 = Arc::new( + TestChunk::new("chunk2") + .with_time_column() + .with_timestamp_min_max(7, 20), + ); + let c3 = Arc::new( + TestChunk::new("chunk3") + .with_time_column() + .with_timestamp_min_max(21, 24), + ); + let c4 = Arc::new( + TestChunk::new("chunk4") + .with_time_column() + .with_timestamp_min_max(25, 35), + ); + + let groups = group_potential_duplicates(vec![c1, c4, c3, c2]); + + let expected = vec![ + "Group 0: [chunk1, chunk2]", + "Group 1: [chunk3]", + "Group 2: [chunk4]", + ]; + assert_groups_eq!(expected, groups); + } + + #[test] + fn one_time_column_overlap_1_chunk() { + let c1 = Arc::new( + TestChunk::new("chunk1") + .with_time_column() + 
.with_timestamp_min_max(1, 10), + ); + + let groups = group_potential_duplicates(vec![c1]); + + let expected = vec!["Group 0: [chunk1]"]; + assert_groups_eq!(expected, groups); + } + + #[test] + fn overlap_no_groups() { + let groups = group_potential_duplicates(vec![]); + + assert!(groups.is_empty()); + } + + #[test] + fn multi_columns_overlap_bad_case() { + let c1 = Arc::new( + TestChunk::new("chunk1") + .with_time_column() + .with_timestamp_min_max(1, 10), + ); + let c2 = Arc::new( + TestChunk::new("chunk2") + .with_time_column() + .with_timestamp_min_max(15, 30) + .with_i64_field_column("field1"), + ); + let c3 = Arc::new( + TestChunk::new("chunk3") + .with_time_column() + .with_timestamp_min_max(7, 20) + .with_tag_column("tag1"), + ); + let c4 = Arc::new( + TestChunk::new("chunk4") + .with_time_column() + .with_timestamp_min_max(25, 35), + ); + + let groups = group_potential_duplicates(vec![c1, c2, c3, c4]); + + let expected = vec!["Group 0: [chunk1, chunk3, chunk2, chunk4]"]; + assert_groups_eq!(expected, groups); + } + + #[test] + fn multi_columns_overlap_1_chunk() { + let c1 = Arc::new( + TestChunk::new("chunk1") + .with_time_column() + .with_timestamp_min_max(1, 10) + .with_tag_column("tag1"), + ); + + let groups = group_potential_duplicates(vec![c1]); + + let expected = vec!["Group 0: [chunk1]"]; + assert_groups_eq!(expected, groups); + } + + #[test] + fn multi_columns_overlap_3_groups() { + let c1 = Arc::new( + TestChunk::new("chunk1") + .with_time_column() + .with_timestamp_min_max(1, 10) + .with_tag_column("tag1"), + ); + let c2 = Arc::new( + TestChunk::new("chunk2") + .with_time_column() + .with_timestamp_min_max(7, 20), + ); + let c3 = Arc::new( + TestChunk::new("chunk3") + .with_time_column() + .with_timestamp_min_max(21, 24) + .with_tag_column("tag2"), + ); + let c4 = Arc::new( + TestChunk::new("chunk4") + .with_time_column() + .with_timestamp_min_max(25, 35), + ); + + let groups = group_potential_duplicates(vec![c1, c4, c3, c2]); + + let expected = 
vec![ + "Group 0: [chunk1, chunk2]", + "Group 1: [chunk3]", + "Group 2: [chunk4]", + ]; + assert_groups_eq!(expected, groups); + } + + // --- Test infrastructure -- + fn to_string(groups: Vec>>) -> Vec { + let mut s = vec![]; + for (idx, group) in groups.iter().enumerate() { + let names = group + .iter() + .map(|c| { + let c = c.as_any().downcast_ref::().unwrap(); + c.table_name() + }) + .collect::>(); + s.push(format!("Group {}: [{}]", idx, names.join(", "))); + } + s + } +} diff --git a/iox_query/src/provider/physical.rs b/iox_query/src/provider/physical.rs new file mode 100644 index 0000000..3114cf8 --- /dev/null +++ b/iox_query/src/provider/physical.rs @@ -0,0 +1,725 @@ +//! Implementation of a DataFusion PhysicalPlan node across partition chunks + +use crate::statistics::build_statistics_for_chunks; +use crate::{ + provider::record_batch_exec::RecordBatchesExec, util::arrow_sort_key_exprs, QueryChunk, + QueryChunkData, CHUNK_ORDER_COLUMN_NAME, +}; +use arrow::datatypes::{Fields, Schema as ArrowSchema, SchemaRef}; +use datafusion::{ + datasource::{ + listing::PartitionedFile, + object_store::ObjectStoreUrl, + physical_plan::{FileScanConfig, ParquetExec}, + }, + physical_expr::PhysicalSortExpr, + physical_plan::{empty::EmptyExec, expressions::Column, union::UnionExec, ExecutionPlan}, + scalar::ScalarValue, +}; +use object_store::ObjectMeta; +use schema::{sort::SortKey, Schema}; +use std::{ + collections::{hash_map::Entry, HashMap, HashSet}, + sync::Arc, +}; + +/// Extension for [`PartitionedFile`] to hold the original [`QueryChunk`] and the [`SortKey`] that was passed to [`chunks_to_physical_nodes`]. +pub struct PartitionedFileExt { + pub chunk: Arc, + pub output_sort_key_memo: Option, +} + +/// Holds a list of chunks that all have the same "URL" and +/// will be scanned using the same ParquetExec. 
+/// +/// Also tracks the overall sort key which is provided to DataFusion +/// plans +#[derive(Debug)] +struct ParquetChunkList { + object_store_url: ObjectStoreUrl, + chunks: Vec<(ObjectMeta, Arc)>, + /// Sort key to place on the ParquetExec, validated to be + /// compatible with all chunk sort keys + sort_key: Option, +} + +impl ParquetChunkList { + /// Create a new chunk list with the specified chunk and overall + /// sort order. If the desired output sort key is specified + /// (e.g. the partition sort key) also computes compatibility with + /// with the chunk order. + fn new( + object_store_url: ObjectStoreUrl, + chunk: &Arc, + meta: ObjectMeta, + output_sort_key: Option<&SortKey>, + ) -> Self { + let sort_key = combine_sort_key(output_sort_key.cloned(), chunk.sort_key(), chunk.schema()); + + Self { + object_store_url, + chunks: vec![(meta, Arc::clone(chunk))], + sort_key, + } + } + + /// Add the parquet file the list of files to be scanned, updating + /// the sort key as necessary. + fn add_parquet_file(&mut self, chunk: &Arc, meta: ObjectMeta) { + self.chunks.push((meta, Arc::clone(chunk))); + + self.sort_key = combine_sort_key(self.sort_key.take(), chunk.sort_key(), chunk.schema()); + } +} + +/// Combines the existing sort key with the sort key of the chunk, +/// returning the new combined compatible sort key that describes both +/// chunks. 
+/// +/// If it is not possible to find a compatible sort key, None is +/// returned signifying "unknown sort order" +fn combine_sort_key( + existing_sort_key: Option, + chunk_sort_key: Option<&SortKey>, + chunk_schema: &Schema, +) -> Option { + if let (Some(existing_sort_key), Some(chunk_sort_key)) = (existing_sort_key, chunk_sort_key) { + let combined_sort_key = SortKey::try_merge_key(&existing_sort_key, chunk_sort_key); + + if let Some(combined_sort_key) = combined_sort_key { + let chunk_cols = chunk_schema + .iter() + .map(|(_t, field)| field.name().as_str()) + .collect::>(); + for (col, _opts) in combined_sort_key.iter() { + if !chunk_sort_key.contains(col.as_ref()) && chunk_cols.contains(col.as_ref()) { + return None; + } + } + } + + // Avoid cloning the sort key when possible, as the sort key + // is likely to commonly be the same + match combined_sort_key { + Some(combined_sort_key) if combined_sort_key == &existing_sort_key => { + Some(existing_sort_key) + } + Some(combined_sort_key) => Some(combined_sort_key.clone()), + None => None, + } + } else { + // no existing sort key means the data wasn't consistently sorted so leave it alone + None + } +} + +/// Place [chunk](QueryChunk)s into physical nodes. +/// +/// This will group chunks into [record batch](QueryChunkData::RecordBatches) and [parquet +/// file](QueryChunkData::Parquet) chunks. The latter will also be grouped by store. +/// +/// Record batch chunks will be turned into a single [`RecordBatchesExec`]. +/// +/// Parquet chunks will be turned into a [`ParquetExec`] per store, each of them with +/// [`target_partitions`](datafusion::execution::context::SessionConfig::target_partitions) file groups. +/// +/// If this function creates more than one physical node, they will be combined using an [`UnionExec`]. Otherwise, a +/// single node will be returned directly. 
+/// +/// If output_sort_key is specified, the ParquetExec will be marked +/// with that sort key, otherwise it will be computed from the input chunks. TODO check if this is helpful or not +/// +/// # Empty Inputs +/// For empty inputs (i.e. no chunks), this will create a single [`EmptyExec`] node with appropriate schema. +/// +/// # Predicates +/// The give `predicate` will only be applied to [`ParquetExec`] nodes since they are the only node type benifiting from +/// pushdown ([`RecordBatchesExec`] has NO builtin filter function). Delete predicates are NOT applied at all. The +/// caller is responsible for wrapping the output node into appropriate filter nodes. +pub fn chunks_to_physical_nodes( + schema: &SchemaRef, + output_sort_key: Option<&SortKey>, + chunks: Vec>, + target_partitions: usize, +) -> Arc { + if chunks.is_empty() { + return Arc::new(EmptyExec::new(Arc::clone(schema))); + } + + let mut record_batch_chunks: Vec> = vec![]; + let mut parquet_chunks: HashMap = HashMap::new(); + + for chunk in &chunks { + match chunk.data() { + QueryChunkData::RecordBatches(_) => { + record_batch_chunks.push(Arc::clone(chunk)); + } + QueryChunkData::Parquet(parquet_input) => { + let url_str = parquet_input.object_store_url.as_str().to_owned(); + match parquet_chunks.entry(url_str) { + Entry::Occupied(mut o) => { + o.get_mut() + .add_parquet_file(chunk, parquet_input.object_meta); + } + Entry::Vacant(v) => { + // better have some instead of no sort information at all + let output_sort_key = output_sort_key.or_else(|| chunk.sort_key()); + v.insert(ParquetChunkList::new( + parquet_input.object_store_url, + chunk, + parquet_input.object_meta, + output_sort_key, + )); + } + } + } + } + } + + let mut output_nodes: Vec> = vec![]; + if !record_batch_chunks.is_empty() { + output_nodes.push(Arc::new(RecordBatchesExec::new( + record_batch_chunks, + Arc::clone(schema), + output_sort_key.cloned(), + ))); + } + let mut parquet_chunks: Vec<_> = parquet_chunks.into_iter().collect(); + 
parquet_chunks.sort_by_key(|(url_str, _)| url_str.clone()); + let has_chunk_order_col = schema.field_with_name(CHUNK_ORDER_COLUMN_NAME).is_ok(); + for (_url_str, chunk_list) in parquet_chunks { + let ParquetChunkList { + object_store_url, + mut chunks, + sort_key, + } = chunk_list; + + // ensure that chunks are actually ordered by chunk order + chunks.sort_by_key(|(_meta, c)| c.order()); + + // Compute statistics for the chunks + let query_chunks = chunks + .iter() + .map(|(_meta, chunk)| Arc::clone(chunk)) + .collect::>(); + let statistics = build_statistics_for_chunks(&query_chunks, Arc::clone(schema)); + + let file_groups = distribute( + chunks.into_iter().map(|(object_meta, chunk)| { + let partition_values = if has_chunk_order_col { + vec![ScalarValue::from(chunk.order().get())] + } else { + vec![] + }; + PartitionedFile { + object_meta, + partition_values, + range: None, + extensions: Some(Arc::new(PartitionedFileExt { + chunk, + output_sort_key_memo: output_sort_key.cloned(), + })), + } + }), + target_partitions, + ); + + // Tell datafusion about the sort key, if any + let output_ordering = sort_key.map(|sort_key| arrow_sort_key_exprs(&sort_key, schema)); + + let (table_partition_cols, file_schema, output_ordering) = if has_chunk_order_col { + let table_partition_cols = vec![schema + .field_with_name(CHUNK_ORDER_COLUMN_NAME) + .unwrap() + .clone()]; + let file_schema = Arc::new(ArrowSchema::new( + schema + .fields + .iter() + .filter(|f| f.name() != CHUNK_ORDER_COLUMN_NAME) + .map(Arc::clone) + .collect::(), + )); + let output_ordering = Some( + output_ordering + .unwrap_or_default() + .into_iter() + .chain(std::iter::once(PhysicalSortExpr { + expr: Arc::new( + Column::new_with_schema(CHUNK_ORDER_COLUMN_NAME, schema) + .expect("just added col"), + ), + options: Default::default(), + })) + .collect::>(), + ); + (table_partition_cols, file_schema, output_ordering) + } else { + (vec![], Arc::clone(schema), output_ordering) + }; + + // No sort order is 
represented by an empty Vec + let output_ordering = vec![output_ordering.unwrap_or_default()]; + + let base_config = FileScanConfig { + object_store_url, + file_schema, + file_groups, + statistics, + projection: None, + limit: None, + table_partition_cols, + output_ordering, + }; + let meta_size_hint = None; + + let parquet_exec = ParquetExec::new(base_config, None, meta_size_hint); + output_nodes.push(Arc::new(parquet_exec)); + } + + assert!(!output_nodes.is_empty()); + Arc::new(UnionExec::new(output_nodes)) +} + +/// Distribute items from the given iterator into `n` containers. +/// +/// This will produce less than `n` containers if the input has less than `n` elements. +/// +/// # Panic +/// Panics if `n` is 0. +fn distribute(it: I, n: usize) -> Vec> +where + I: IntoIterator, +{ + assert!(n > 0); + + let mut outputs: Vec<_> = (0..n).map(|_| vec![]).collect(); + let mut pos = 0usize; + for x in it { + outputs[pos].push(x); + pos = (pos + 1) % n; + } + outputs.into_iter().filter(|o| !o.is_empty()).collect() +} + +#[cfg(test)] +mod tests { + use datafusion::{ + common::stats::Precision, + physical_plan::{ColumnStatistics, Statistics}, + }; + use schema::{sort::SortKeyBuilder, InfluxFieldType, SchemaBuilder, TIME_COLUMN_NAME}; + + use crate::{ + chunk_order_field, + statistics::build_statistics_for_chunks, + test::{format_execution_plan, TestChunk}, + }; + + use super::*; + + #[test] + fn test_distribute() { + assert_eq!(distribute(0..0u8, 1), Vec::>::new(),); + + assert_eq!(distribute(0..3u8, 1), vec![vec![0, 1, 2]],); + + assert_eq!(distribute(0..3u8, 2), vec![vec![0, 2], vec![1]],); + + assert_eq!(distribute(0..3u8, 10), vec![vec![0], vec![1], vec![2]],); + } + + #[test] + fn test_combine_sort_key() { + let schema_t1 = SchemaBuilder::new().tag("t1").timestamp().build().unwrap(); + let skey_t1 = SortKeyBuilder::new() + .with_col("t1") + .with_col(TIME_COLUMN_NAME) + .build(); + + let schema_t1_t2 = SchemaBuilder::new() + .tag("t1") + .tag("t2") + .timestamp() + 
.build() + .unwrap(); + let skey_t1_t2 = SortKeyBuilder::new() + .with_col("t1") + .with_col("t2") + .with_col(TIME_COLUMN_NAME) + .build(); + + let skey_t2_t1 = SortKeyBuilder::new() + .with_col("t2") + .with_col("t1") + .with_col(TIME_COLUMN_NAME) + .build(); + + // output is None if any of the parameters is None (either no sort key requested or chunk is unsorted) + assert_eq!(combine_sort_key(None, None, &schema_t1), None); + assert_eq!( + combine_sort_key(Some(skey_t1.clone()), None, &schema_t1), + None + ); + assert_eq!(combine_sort_key(None, Some(&skey_t1), &schema_t1), None); + + // keeping sort key identical works + assert_eq!( + combine_sort_key(Some(skey_t1.clone()), Some(&skey_t1), &schema_t1), + Some(skey_t1.clone()) + ); + assert_eq!( + combine_sort_key(Some(skey_t1.clone()), Some(&skey_t1), &schema_t1_t2), + Some(skey_t1.clone()) + ); + + // extending sort key works (chunk has more columns than existing key) + assert_eq!( + combine_sort_key(Some(skey_t1.clone()), Some(&skey_t1_t2), &schema_t1_t2), + Some(skey_t1_t2.clone()) + ); + + // extending sort key works (quorum has more columns than this chunk) + assert_eq!( + combine_sort_key(Some(skey_t1_t2.clone()), Some(&skey_t1), &schema_t1), + Some(skey_t1_t2.clone()) + ); + assert_eq!( + combine_sort_key(Some(skey_t2_t1.clone()), Some(&skey_t1), &schema_t1), + Some(skey_t2_t1.clone()) + ); + + // extending does not work if quorum covers columns that the chunk has but that are NOT sorted for that chunk + assert_eq!( + combine_sort_key(Some(skey_t1_t2.clone()), Some(&skey_t1), &schema_t1_t2), + None + ); + assert_eq!( + combine_sort_key(Some(skey_t2_t1.clone()), Some(&skey_t1), &schema_t1_t2), + None + ); + + // column order conflicts are detected + assert_eq!( + combine_sort_key(Some(skey_t2_t1), Some(&skey_t1_t2), &schema_t1_t2), + None + ); + } + + #[test] + fn test_chunks_to_physical_nodes_empty() { + let schema = TestChunk::new("table").schema().as_arrow(); + let plan = 
chunks_to_physical_nodes(&schema, None, vec![], 2); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " EmptyExec" + "### + ); + } + + #[test] + fn test_chunks_to_physical_nodes_recordbatch() { + let chunk = TestChunk::new("table"); + let schema = chunk.schema().as_arrow(); + let plan = chunks_to_physical_nodes(&schema, None, vec![Arc::new(chunk)], 2); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " UnionExec" + - " RecordBatchesExec: chunks=1" + "### + ); + } + + #[test] + fn test_chunks_to_physical_nodes_parquet_one_file() { + let chunk = TestChunk::new("table").with_dummy_parquet_file(); + let schema = chunk.schema().as_arrow(); + let plan = chunks_to_physical_nodes(&schema, None, vec![Arc::new(chunk)], 2); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}" + "### + ); + } + + #[test] + fn test_chunks_to_physical_nodes_parquet_many_files() { + let chunk1 = TestChunk::new("table").with_id(0).with_dummy_parquet_file(); + let chunk2 = TestChunk::new("table").with_id(1).with_dummy_parquet_file(); + let chunk3 = TestChunk::new("table").with_id(2).with_dummy_parquet_file(); + let schema = chunk1.schema().as_arrow(); + let plan = chunks_to_physical_nodes( + &schema, + None, + vec![Arc::new(chunk1), Arc::new(chunk2), Arc::new(chunk3)], + 2, + ); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " UnionExec" + - " ParquetExec: file_groups={2 groups: [[0.parquet, 2.parquet], [1.parquet]]}" + "### + ); + } + + #[test] + fn test_chunks_to_physical_nodes_parquet_many_store() { + let chunk1 = TestChunk::new("table") + .with_id(0) + .with_dummy_parquet_file_and_store("iox1://"); + let chunk2 = TestChunk::new("table") + .with_id(1) + .with_dummy_parquet_file_and_store("iox2://"); + let schema = chunk1.schema().as_arrow(); + let plan = + chunks_to_physical_nodes(&schema, None, 
vec![Arc::new(chunk1), Arc::new(chunk2)], 2); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}" + - " ParquetExec: file_groups={1 group: [[1.parquet]]}" + "### + ); + } + + #[test] + fn test_chunks_to_physical_nodes_mixed() { + let chunk1 = TestChunk::new("table").with_dummy_parquet_file(); + let chunk2 = TestChunk::new("table"); + let schema = chunk1.schema().as_arrow(); + let plan = + chunks_to_physical_nodes(&schema, None, vec![Arc::new(chunk1), Arc::new(chunk2)], 2); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " UnionExec" + - " RecordBatchesExec: chunks=1" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}" + "### + ); + } + + #[test] + fn test_chunks_to_physical_nodes_mixed_with_chunk_order() { + let chunk1 = TestChunk::new("table") + .with_tag_column("tag") + .with_dummy_parquet_file(); + let chunk2 = TestChunk::new("table").with_tag_column("tag"); + let schema = Arc::new(ArrowSchema::new( + chunk1 + .schema() + .as_arrow() + .fields + .iter() + .cloned() + .chain(std::iter::once(chunk_order_field())) + .collect::(), + )); + let plan = + chunks_to_physical_nodes(&schema, None, vec![Arc::new(chunk1), Arc::new(chunk2)], 2); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan), + @r###" + --- + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[tag, __chunk_order]" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[tag, __chunk_order], output_ordering=[__chunk_order@1 ASC]" + "### + ); + } + + // reproducer of https://github.com/influxdata/idpe/issues/18287 + #[test] + fn reproduce_schema_bug_in_parquet_exec() { + // schema with one tag, one filed, time and CHUNK_ORDER_COLUMN_NAME + let schema: SchemaRef = SchemaBuilder::new() + .tag("tag") + .influx_field("field", InfluxFieldType::Float) + .timestamp() + .influx_field(CHUNK_ORDER_COLUMN_NAME, InfluxFieldType::Integer) + 
.build() + .unwrap() + .into(); + + // create a test chunk with one tag, one filed, time and CHUNK_ORDER_COLUMN_NAME + let record_batch_chunk = Arc::new( + TestChunk::new("t") + .with_tag_column_with_stats("tag", Some("AL"), Some("MT")) + .with_time_column_with_stats(Some(10), Some(20)) + .with_i64_field_column_with_stats("field", Some(0), Some(100)) + .with_i64_field_column_with_stats(CHUNK_ORDER_COLUMN_NAME, Some(5), Some(6)), + ); + + // create them same test chunk but with a parquet file + let parquet_chunk = Arc::new( + TestChunk::new("t") + .with_tag_column_with_stats("tag", Some("AL"), Some("MT")) + .with_i64_field_column_with_stats("field", Some(0), Some(100)) + .with_time_column_with_stats(Some(10), Some(20)) + .with_i64_field_column_with_stats(CHUNK_ORDER_COLUMN_NAME, Some(5), Some(6)) + .with_dummy_parquet_file(), + ); + + // Build a RecordBatchsExec for record_batch_chunk + // + // Use chunks_to_physical_nodes to build a plan with UnionExec on top of RecordBatchesExec + // Note: I purposely use chunks_to_physical_node to create plan for both record_batch_chunk and parquet_chunk to + // consistently create their plan. 
Also chunks_to_physical_nodes is used to create plans in the optimization + passes that I will need
When we get the plan's statistics, we won't care about CHUNK_ORDER_COLUMN_NAME because it is not a real column.
Precision::Exact(ScalarValue::Utf8(Some("AL".to_string()))), + distinct_count: Precision::Absent, + }, + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Int64(Some(100))), + min_value: Precision::Exact(ScalarValue::Int64(Some(0))), + distinct_count: Precision::Absent, + }, + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::TimestampNanosecond(Some(20), None)), + min_value: Precision::Exact(ScalarValue::TimestampNanosecond(Some(10), None)), + distinct_count: Precision::Absent, + }, + ]; + // + // Add CHUNK_ORDER_COLUMN_NAME with stats + let mut parquet_file_col_stats = col_stats.clone(); + parquet_file_col_stats.push(ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Int64(Some(6))), + min_value: Precision::Exact(ScalarValue::Int64(Some(0))), + distinct_count: Precision::Absent, + }); + // + // Add CHUNK_ORDER_COLUMN_NAME without stats + let mut parquet_plan_stats_col_stats = col_stats; + parquet_plan_stats_col_stats.push(ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + distinct_count: Precision::Absent, + }); + // + let expected_parquet_plan_stats = Statistics { + num_rows: Precision::Exact(0), + total_byte_size: Precision::Absent, + column_statistics: parquet_plan_stats_col_stats, + }; + // + let expected_parquet_file_stats = Statistics { + num_rows: Precision::Exact(0), + total_byte_size: Precision::Absent, + column_statistics: parquet_file_col_stats, + }; + + // Content of Record batch plan stats that include stats of CHUNK_ORDER_COLUMN_NAME + assert_eq!(record_batch_plan_stats, expected_parquet_file_stats); + // Content of parquet file stats that also include stats of CHUNK_ORDER_COLUMN_NAME + assert_eq!(*parqet_file_stats, expected_parquet_file_stats); + // + // Content of parquet plan stats that does not include stats of CHUNK_ORDER_COLUMN_NAME + 
assert_eq!(parquet_plan_stats, expected_parquet_plan_stats); + } +} diff --git a/iox_query/src/provider/progressive_eval.rs b/iox_query/src/provider/progressive_eval.rs new file mode 100644 index 0000000..80109e4 --- /dev/null +++ b/iox_query/src/provider/progressive_eval.rs @@ -0,0 +1,1206 @@ +// ProgressiveEvalExec (step 1 in https://docs.google.com/document/d/1x1yf9ggyxD4JPT8Gf9YlIKxUawqoKTJ1HFyTbGin9xY/edit) +// This will be moved to DF once it is ready + +//! Defines the progressive eval plan + +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion::common::{internal_err, DataFusionError, Result}; +use datafusion::execution::TaskContext; +use datafusion::physical_expr::{EquivalenceProperties, PhysicalSortExpr, PhysicalSortRequirement}; +use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; +use datafusion::physical_plan::stream::RecordBatchReceiverStream; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, RecordBatchStream, + SendableRecordBatchStream, Statistics, +}; +use datafusion::scalar::ScalarValue; +use futures::{ready, Stream, StreamExt}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use observability_deps::tracing::{debug, trace}; + +/// ProgressiveEval return a stream of record batches in the order of its inputs. +/// It will stop when the number of output rows reach the given limit. +/// +/// This takes an input execution plan and a number n, and provided each partition of +/// the input plan is in an expected order, this operator will return top record batches that covers the top n rows +/// in the order of the input plan. 
+/// +/// ```text +/// ┌─────────────────────────┐ +/// │ ┌───┬───┬───┬───┐ │ +/// │ │ A │ B │ C │ D │ │──┐ +/// │ └───┴───┴───┴───┘ │ │ +/// └─────────────────────────┘ │ ┌───────────────────┐ ┌───────────────────────────────┐ +/// Stream 1 │ │ │ │ ┌───┬───╦═══╦───┬───╦═══╗ │ +/// ├─▶│ ProgressiveEval │───▶│ │ A │ B ║ C ║ D │ M ║ N ║ ... │ +/// │ │ │ │ └───┴─▲─╩═══╩───┴───╩═══╝ │ +/// ┌─────────────────────────┐ │ └───────────────────┘ └─┬─────┴───────────────────────┘ +/// │ ╔═══╦═══╗ │ │ +/// │ ║ M ║ N ║ │──┘ │ +/// │ ╚═══╩═══╝ │ Output only include top record batches that cover top N rows +/// └─────────────────────────┘ +/// Stream 2 +/// +/// +/// Input Streams Output stream +/// (in some order) (in same order) +/// ``` +#[derive(Debug)] +pub(crate) struct ProgressiveEvalExec { + /// Input plan + input: Arc, + + /// Corresponding value ranges of the input plan + /// None if the value ranges are not available + value_ranges: Option>, + + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + + /// Optional number of rows to fetch. 
Stops producing rows after this fetch + fetch: Option, +} + +impl ProgressiveEvalExec { + /// Create a new progressive execution plan + pub fn new( + input: Arc, + value_ranges: Option>, + fetch: Option, + ) -> Self { + Self { + input, + value_ranges, + metrics: ExecutionPlanMetricsSet::new(), + fetch, + } + } + + /// Input schema + pub fn input(&self) -> &Arc { + &self.input + } +} + +impl DisplayAs for ProgressiveEvalExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "ProgressiveEvalExec: ")?; + if let Some(fetch) = self.fetch { + write!(f, "fetch={fetch}, ")?; + }; + if let Some(value_ranges) = &self.value_ranges { + write!(f, "input_ranges={value_ranges:?}")?; + }; + + Ok(()) + } + } + } +} + +impl ExecutionPlan for ProgressiveEvalExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.input.schema() + } + + /// Get the output partitioning of this plan + fn output_partitioning(&self) -> Partitioning { + // This node serializes all the data to a single partition + Partitioning::UnknownPartitioning(1) + } + + /// Specifies whether this plan generates an infinite stream of records. + /// If the plan does not support pipelining, but its input(s) are + /// infinite, returns an error to indicate this. 
// todo: maybe in the future we do not need this parallelism if the number of fetched rows is in the first stream
ProgressiveEvalStream::new_from_receivers"); + + Ok(Box::pin(result)) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + self.input.statistics() + } + + fn equivalence_properties(&self) -> EquivalenceProperties { + // progressive eval does not change the equivalence properties of its input + self.input.equivalence_properties() + } +} + +/// Concat input streams until reaching the fetch limit +struct ProgressiveEvalStream { + /// input streams + input_streams: Vec, + + /// The schema of the input and output. + schema: SchemaRef, + + /// used to record execution metrics + metrics: BaselineMetrics, + + /// Index of current stream + current_stream_idx: usize, + + /// If the stream has encountered an error + aborted: bool, + + /// Optional number of rows to fetch + fetch: Option, + + /// number of rows produced + produced: usize, +} + +impl ProgressiveEvalStream { + fn new( + input_streams: Vec, + schema: SchemaRef, + metrics: BaselineMetrics, + fetch: Option, + ) -> Result { + Ok(Self { + input_streams, + schema, + metrics, + current_stream_idx: 0, + aborted: false, + fetch, + produced: 0, + }) + } +} + +impl Stream for ProgressiveEvalStream { + type Item = Result; + + // Return the next record batch until reaching the fetch limit or the end of all input streams + // Return pending if the next record batch is not ready + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Error in previous poll + if self.aborted { + return Poll::Ready(None); + } + + // Have reached the fetch limit + if self.produced >= self.fetch.unwrap_or(std::usize::MAX) { + return Poll::Ready(None); + } + + // Have reached the end of all input streams + if self.current_stream_idx >= self.input_streams.len() { + return Poll::Ready(None); + } + + // Get next record batch + let mut poll; + loop { + let idx = self.current_stream_idx; + poll = self.input_streams[idx].poll_next_unpin(cx); + match poll { + // This 
// Reaching here means the data of all streams has been read
// receiver dropped when query is shutdown early (e.g., limit) or error, + // no need to propagate the send error.
// still return all rows even when selecting 3 rows because the first record batch is returned
1970-01-01T00:00:00.000000005 |", + "| 3 | j | 1970-01-01T00:00:00.000000008 |", + "+---+---+-------------------------------+", + ], + Arc::clone(&task_ctx), + ) + .await; + } + + #[tokio::test] + async fn test_return_all() { + let task_ctx = Arc::new(TaskContext::default()); + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("a"), + Some("c"), + Some("e"), + Some("g"), + Some("j"), + ])); + let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8])); + let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("b"), + Some("d"), + Some("f"), + Some("h"), + Some("j"), + ])); + let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6])); + let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + // [b1, b2] + _test_progressive_eval( + &[vec![b1.clone()], vec![b2.clone()]], + None, + None, // no fetch limit --> return all rows + &[ + "+----+---+-------------------------------+", + "| a | b | c |", + "+----+---+-------------------------------+", + "| 1 | a | 1970-01-01T00:00:00.000000008 |", + "| 2 | c | 1970-01-01T00:00:00.000000007 |", + "| 7 | e | 1970-01-01T00:00:00.000000006 |", + "| 9 | g | 1970-01-01T00:00:00.000000005 |", + "| 3 | j | 1970-01-01T00:00:00.000000008 |", + "| 10 | b | 1970-01-01T00:00:00.000000004 |", + "| 20 | d | 1970-01-01T00:00:00.000000006 |", + "| 70 | f | 1970-01-01T00:00:00.000000002 |", + "| 90 | h | 1970-01-01T00:00:00.000000002 |", + "| 30 | j | 1970-01-01T00:00:00.000000006 |", + "+----+---+-------------------------------+", + ], + Arc::clone(&task_ctx), + ) + .await; + + // [b2, b1] + _test_progressive_eval( + &[vec![b2], vec![b1]], + None, + None, + &[ + "+----+---+-------------------------------+", + "| a | b | c |", + 
"+----+---+-------------------------------+", + "| 10 | b | 1970-01-01T00:00:00.000000004 |", + "| 20 | d | 1970-01-01T00:00:00.000000006 |", + "| 70 | f | 1970-01-01T00:00:00.000000002 |", + "| 90 | h | 1970-01-01T00:00:00.000000002 |", + "| 30 | j | 1970-01-01T00:00:00.000000006 |", + "| 1 | a | 1970-01-01T00:00:00.000000008 |", + "| 2 | c | 1970-01-01T00:00:00.000000007 |", + "| 7 | e | 1970-01-01T00:00:00.000000006 |", + "| 9 | g | 1970-01-01T00:00:00.000000005 |", + "| 3 | j | 1970-01-01T00:00:00.000000008 |", + "+----+---+-------------------------------+", + ], + task_ctx, + ) + .await; + } + + #[tokio::test] + async fn test_return_all_on_different_length_batches() { + let task_ctx = Arc::new(TaskContext::default()); + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("e"), + ])); + let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8])); + let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("c"), + Some("d"), + Some("e"), + ])); + let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2])); + let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + // [b1, b2] + _test_progressive_eval( + &[vec![b1.clone()], vec![b2.clone()]], + None, + None, + &[ + "+----+---+-------------------------------+", + "| a | b | c |", + "+----+---+-------------------------------+", + "| 1 | a | 1970-01-01T00:00:00.000000008 |", + "| 2 | b | 1970-01-01T00:00:00.000000007 |", + "| 7 | c | 1970-01-01T00:00:00.000000006 |", + "| 9 | d | 1970-01-01T00:00:00.000000005 |", + "| 3 | e | 1970-01-01T00:00:00.000000008 |", + "| 70 | c | 1970-01-01T00:00:00.000000004 |", + "| 90 | d | 1970-01-01T00:00:00.000000006 |", + "| 30 | e | 
1970-01-01T00:00:00.000000002 |", + "+----+---+-------------------------------+", + ], + Arc::clone(&task_ctx), + ) + .await; + + // [b2, b1] + _test_progressive_eval( + &[vec![b2], vec![b1]], + None, + None, + &[ + "+----+---+-------------------------------+", + "| a | b | c |", + "+----+---+-------------------------------+", + "| 70 | c | 1970-01-01T00:00:00.000000004 |", + "| 90 | d | 1970-01-01T00:00:00.000000006 |", + "| 30 | e | 1970-01-01T00:00:00.000000002 |", + "| 1 | a | 1970-01-01T00:00:00.000000008 |", + "| 2 | b | 1970-01-01T00:00:00.000000007 |", + "| 7 | c | 1970-01-01T00:00:00.000000006 |", + "| 9 | d | 1970-01-01T00:00:00.000000005 |", + "| 3 | e | 1970-01-01T00:00:00.000000008 |", + "+----+---+-------------------------------+", + ], + task_ctx, + ) + .await; + } + + #[tokio::test] + async fn test_fetch_limit_1() { + let task_ctx = Arc::new(TaskContext::default()); + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("e"), + ])); + let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8])); + let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("c"), + Some("d"), + Some("e"), + ])); + let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2])); + let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + // [b2, b1] + // b2 has 3 rows. 
b1 has 5 rows + // Fetch limit is 1 --> return all 3 rows of the first batch (b2) that covers that limit + _test_progressive_eval( + &[vec![b2.clone()], vec![b1.clone()]], + None, + Some(1), + &[ + "+----+---+-------------------------------+", + "| a | b | c |", + "+----+---+-------------------------------+", + "| 70 | c | 1970-01-01T00:00:00.000000004 |", + "| 90 | d | 1970-01-01T00:00:00.000000006 |", + "| 30 | e | 1970-01-01T00:00:00.000000002 |", + "+----+---+-------------------------------+", + ], + Arc::clone(&task_ctx), + ) + .await; + + // [b1, b2] + // b1 has 5 rows. b2 has 3 rows + // Fetch limit is 1 --> return all 5 rows of the first batch (b1) that covers that limit + _test_progressive_eval( + &[vec![b1], vec![b2]], + None, + Some(1), + &[ + "+---+---+-------------------------------+", + "| a | b | c |", + "+---+---+-------------------------------+", + "| 1 | a | 1970-01-01T00:00:00.000000008 |", + "| 2 | b | 1970-01-01T00:00:00.000000007 |", + "| 7 | c | 1970-01-01T00:00:00.000000006 |", + "| 9 | d | 1970-01-01T00:00:00.000000005 |", + "| 3 | e | 1970-01-01T00:00:00.000000008 |", + "+---+---+-------------------------------+", + ], + task_ctx, + ) + .await; + } + + #[tokio::test] + async fn test_fetch_limit_equal_first_batch_size() { + let task_ctx = Arc::new(TaskContext::default()); + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("e"), + ])); + let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8])); + let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("c"), + Some("d"), + Some("e"), + ])); + let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2])); + let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), 
("c", c)]).unwrap(); + + // [b2, b1] + // b2 has 3 rows. b1 has 5 rows + // Fetch limit is 3 --> return all 3 rows of the first batch (b2) that covers that limit + _test_progressive_eval( + &[vec![b2.clone()], vec![b1.clone()]], + None, + Some(3), + &[ + "+----+---+-------------------------------+", + "| a | b | c |", + "+----+---+-------------------------------+", + "| 70 | c | 1970-01-01T00:00:00.000000004 |", + "| 90 | d | 1970-01-01T00:00:00.000000006 |", + "| 30 | e | 1970-01-01T00:00:00.000000002 |", + "+----+---+-------------------------------+", + ], + Arc::clone(&task_ctx), + ) + .await; + + // [b1, b2] + // b1 has 5 rows. b2 has 3 rows + // Fetch limit is 5 --> return all 5 rows of first batch (b1) that covers that limit + _test_progressive_eval( + &[vec![b1], vec![b2]], + None, + Some(5), + &[ + "+---+---+-------------------------------+", + "| a | b | c |", + "+---+---+-------------------------------+", + "| 1 | a | 1970-01-01T00:00:00.000000008 |", + "| 2 | b | 1970-01-01T00:00:00.000000007 |", + "| 7 | c | 1970-01-01T00:00:00.000000006 |", + "| 9 | d | 1970-01-01T00:00:00.000000005 |", + "| 3 | e | 1970-01-01T00:00:00.000000008 |", + "+---+---+-------------------------------+", + ], + task_ctx, + ) + .await; + } + + #[tokio::test] + async fn test_fetch_limit_over_first_batch_size() { + let task_ctx = Arc::new(TaskContext::default()); + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("e"), + ])); + let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8])); + let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("c"), + Some("d"), + Some("e"), + ])); + let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2])); + let b2 = 
RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + // [b2, b1] + // b2 has 3 rows. b1 has 5 rows + // Fetch limit is 4 --> return all rows of both batches in the order of b2, b1 + _test_progressive_eval( + &[vec![b2.clone()], vec![b1.clone()]], + None, + Some(4), + &[ + "+----+---+-------------------------------+", + "| a | b | c |", + "+----+---+-------------------------------+", + "| 70 | c | 1970-01-01T00:00:00.000000004 |", + "| 90 | d | 1970-01-01T00:00:00.000000006 |", + "| 30 | e | 1970-01-01T00:00:00.000000002 |", + "| 1 | a | 1970-01-01T00:00:00.000000008 |", + "| 2 | b | 1970-01-01T00:00:00.000000007 |", + "| 7 | c | 1970-01-01T00:00:00.000000006 |", + "| 9 | d | 1970-01-01T00:00:00.000000005 |", + "| 3 | e | 1970-01-01T00:00:00.000000008 |", + "+----+---+-------------------------------+", + ], + Arc::clone(&task_ctx), + ) + .await; + + // [b1, b2] + // b1 has 5 rows. b2 has 3 rows + // Fetch limit is 6 --> return all rows of both batches in the order of b1, b2 + _test_progressive_eval( + &[vec![b1], vec![b2]], + None, + Some(6), + &[ + "+----+---+-------------------------------+", + "| a | b | c |", + "+----+---+-------------------------------+", + "| 1 | a | 1970-01-01T00:00:00.000000008 |", + "| 2 | b | 1970-01-01T00:00:00.000000007 |", + "| 7 | c | 1970-01-01T00:00:00.000000006 |", + "| 9 | d | 1970-01-01T00:00:00.000000005 |", + "| 3 | e | 1970-01-01T00:00:00.000000008 |", + "| 70 | c | 1970-01-01T00:00:00.000000004 |", + "| 90 | d | 1970-01-01T00:00:00.000000006 |", + "| 30 | e | 1970-01-01T00:00:00.000000002 |", + "+----+---+-------------------------------+", + ], + task_ctx, + ) + .await; + } + + #[tokio::test] + async fn test_three_partitions_with_nulls() { + let task_ctx = Arc::new(TaskContext::default()); + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("a"), + Some("b"), + Some("c"), + None, + Some("f"), + ])); + let c: ArrayRef = 
Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8])); + let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("e"), + Some("g"), + Some("h"), + ])); + let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![40, 60, 20])); + let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + let a: ArrayRef = Arc::new(Int32Array::from(vec![100, 200, 700, 900])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + None, + Some("g"), + Some("h"), + Some("i"), + ])); + let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2])); + let b3 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + // [b1, b2, b3] + // b1 has 5 rows. b2 has 3 rows. b3 has 4 rows + // Fetch limit is 1 --> return all rows of the b1 + _test_progressive_eval( + &[vec![b1.clone()], vec![b2.clone()], vec![b3.clone()]], + None, + Some(1), + &[ + "+---+---+-------------------------------+", + "| a | b | c |", + "+---+---+-------------------------------+", + "| 1 | a | 1970-01-01T00:00:00.000000008 |", + "| 2 | b | 1970-01-01T00:00:00.000000007 |", + "| 7 | c | 1970-01-01T00:00:00.000000006 |", + "| 9 | | 1970-01-01T00:00:00.000000005 |", + "| 3 | f | 1970-01-01T00:00:00.000000008 |", + "+---+---+-------------------------------+", + ], + Arc::clone(&task_ctx), + ) + .await; + + // [b1, b2, b3] + // b1 has 5 rows. b2 has 3 rows. 
b3 has 4 rows + // Fetch limit is 7 --> return all rows of the b1 & b2 in the order of b1, b2 + _test_progressive_eval( + &[vec![b1.clone()], vec![b2.clone()], vec![b3.clone()]], + None, + Some(7), + &[ + "+----+---+-------------------------------+", + "| a | b | c |", + "+----+---+-------------------------------+", + "| 1 | a | 1970-01-01T00:00:00.000000008 |", + "| 2 | b | 1970-01-01T00:00:00.000000007 |", + "| 7 | c | 1970-01-01T00:00:00.000000006 |", + "| 9 | | 1970-01-01T00:00:00.000000005 |", + "| 3 | f | 1970-01-01T00:00:00.000000008 |", + "| 10 | e | 1970-01-01T00:00:00.000000040 |", + "| 20 | g | 1970-01-01T00:00:00.000000060 |", + "| 70 | h | 1970-01-01T00:00:00.000000020 |", + "+----+---+-------------------------------+", + ], + Arc::clone(&task_ctx), + ) + .await; + + // [b1, b2, b3] + // b1 has 5 rows. b2 has 3 rows. b3 has 4 rows + // Fetch limit is 50 --> return all rows of all batches in the order of b1, b2, b3 + _test_progressive_eval( + &[vec![b1], vec![b2], vec![b3]], + None, + Some(50), + &[ + "+-----+---+-------------------------------+", + "| a | b | c |", + "+-----+---+-------------------------------+", + "| 1 | a | 1970-01-01T00:00:00.000000008 |", + "| 2 | b | 1970-01-01T00:00:00.000000007 |", + "| 7 | c | 1970-01-01T00:00:00.000000006 |", + "| 9 | | 1970-01-01T00:00:00.000000005 |", + "| 3 | f | 1970-01-01T00:00:00.000000008 |", + "| 10 | e | 1970-01-01T00:00:00.000000040 |", + "| 20 | g | 1970-01-01T00:00:00.000000060 |", + "| 70 | h | 1970-01-01T00:00:00.000000020 |", + "| 100 | | 1970-01-01T00:00:00.000000004 |", + "| 200 | g | 1970-01-01T00:00:00.000000006 |", + "| 700 | h | 1970-01-01T00:00:00.000000002 |", + "| 900 | i | 1970-01-01T00:00:00.000000002 |", + "+-----+---+-------------------------------+", + ], + task_ctx, + ) + .await; + } + + async fn _test_progressive_eval( + partitions: &[Vec], + value_ranges: Option>, + fetch: Option, + exp: &[&str], + context: Arc, + ) { + let schema = if partitions.is_empty() { + // just whatwever 
schema + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); + let batch = RecordBatch::try_from_iter(vec![("a", a)]).unwrap(); + batch.schema() + } else { + partitions[0][0].schema() + }; + + let exec = MemoryExec::try_new(partitions, schema, None).unwrap(); + let progressive = Arc::new(ProgressiveEvalExec::new( + Arc::new(exec), + value_ranges, + fetch, + )); + + let collected = collect(progressive, context).await.unwrap(); + assert_batches_eq!(exp, collected.as_slice()); + } + + #[tokio::test] + async fn test_merge_metrics() { + let task_ctx = Arc::new(TaskContext::default()); + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")])); + let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); + + let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("b"), Some("d")])); + let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); + + let schema = b1.schema(); + let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); + let progressive = Arc::new(ProgressiveEvalExec::new(Arc::new(exec), None, None)); + + let collected = collect(Arc::::clone(&progressive), task_ctx) + .await + .unwrap(); + let expected = [ + "+----+---+", + "| a | b |", + "+----+---+", + "| 1 | a |", + "| 2 | c |", + "| 10 | b |", + "| 20 | d |", + "+----+---+", + ]; + assert_batches_eq!(expected, collected.as_slice()); + + // Now, validate metrics + let metrics = progressive.metrics().unwrap(); + + assert_eq!(metrics.output_rows().unwrap(), 4); + assert!(metrics.elapsed_compute().unwrap() > 0); + + let mut saw_start = false; + let mut saw_end = false; + metrics.iter().for_each(|m| match m.value() { + MetricValue::StartTimestamp(ts) => { + saw_start = true; + assert!(nanos_from_timestamp(ts) > 0); + } + MetricValue::EndTimestamp(ts) => { + saw_end = true; + assert!(nanos_from_timestamp(ts) > 0); 
+ } + _ => {} + }); + + assert!(saw_start); + assert!(saw_end); + } + + fn nanos_from_timestamp(ts: &Timestamp) -> i64 { + ts.value().unwrap().timestamp_nanos_opt().unwrap() + } + + #[tokio::test] + async fn test_drop_cancel() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); + + let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2)); + let refs = blocking_exec.refs(); + let progressive_exec = Arc::new(ProgressiveEvalExec::new(blocking_exec, None, None)); + + let fut = collect(progressive_exec, task_ctx); + let mut fut = fut.boxed(); + + assert_is_pending(&mut fut); + drop(fut); + assert_strong_count_converges_to_zero(refs).await; + + Ok(()) + } + + // todo: this is copied from DF. When we move ProgressiveEval to DF, this will be removed + /// Asserts that the strong count of the given [`Weak`] pointer converges to zero. + /// + /// This might take a while but has a timeout. + pub async fn assert_strong_count_converges_to_zero(refs: Weak) { + #![allow(clippy::future_not_send)] + tokio::time::timeout(std::time::Duration::from_secs(10), async { + loop { + if Weak::strong_count(&refs) == 0 { + break; + } + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + }) + .await + .unwrap(); + } + + // todo: this is copied from DF. When we move ProgressiveEval to DF, this will be removed + /// Asserts that given future is pending. + pub fn assert_is_pending<'a, T>(fut: &mut Pin + Send + 'a>>) { + let waker = futures::task::noop_waker(); + let mut cx = futures::task::Context::from_waker(&waker); + let poll = fut.poll_unpin(&mut cx); + + assert!(poll.is_pending()); + } + + // todo: this is copied from DF. When we move ProgressiveEval to DF, this will be removed + /// Execution plan that emits streams that block forever. + /// + /// This is useful to test shutdown / cancelation behavior of certain execution plans. 
+ #[derive(Debug)] + pub struct BlockingExec { + /// Schema that is mocked by this plan. + schema: SchemaRef, + + /// Number of output partitions. + n_partitions: usize, + + /// Ref-counting helper to check if the plan and the produced stream are still in memory. + refs: Arc<()>, + } + + impl BlockingExec { + /// Create new [`BlockingExec`] with a give schema and number of partitions. + pub fn new(schema: SchemaRef, n_partitions: usize) -> Self { + Self { + schema, + n_partitions, + refs: Default::default(), + } + } + + /// Weak pointer that can be used for ref-counting this execution plan and its streams. + /// + /// Use [`Weak::strong_count`] to determine if the plan itself and its streams are dropped (should be 0 in that + /// case). Note that tokio might take some time to cancel spawned tasks, so you need to wrap this check into a retry + /// loop. Use [`assert_strong_count_converges_to_zero`] to archive this. + pub fn refs(&self) -> Weak<()> { + Arc::downgrade(&self.refs) + } + } + + impl DisplayAs for BlockingExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "BlockingExec",) + } + } + } + } + + impl ExecutionPlan for BlockingExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn children(&self) -> Vec> { + // this is a leaf node and has no children + vec![] + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(self.n_partitions) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + internal_err!("Children cannot be replaced in {self:?}") + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + Ok(Box::pin(BlockingStream { + schema: Arc::clone(&self.schema), + _refs: 
Arc::clone(&self.refs), + })) + } + + fn statistics(&self) -> Result { + unimplemented!() + } + } + + /// A [`RecordBatchStream`] that is pending forever. + #[derive(Debug)] + pub struct BlockingStream { + /// Schema mocked by this stream. + schema: SchemaRef, + + /// Ref-counting helper to check if the stream are still in memory. + _refs: Arc<()>, + } + + impl Stream for BlockingStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Pending + } + } + + impl RecordBatchStream for BlockingStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + } +} diff --git a/iox_query/src/provider/record_batch_exec.rs b/iox_query/src/provider/record_batch_exec.rs new file mode 100644 index 0000000..6122286 --- /dev/null +++ b/iox_query/src/provider/record_batch_exec.rs @@ -0,0 +1,191 @@ +//! Implementation of a DataFusion PhysicalPlan node across partition chunks + +use crate::statistics::build_statistics_for_chunks; +use crate::{QueryChunk, CHUNK_ORDER_COLUMN_NAME}; + +use super::adapter::SchemaAdapterStream; +use arrow::datatypes::SchemaRef; +use datafusion::physical_plan::display::ProjectSchemaDisplay; +use datafusion::{ + error::DataFusionError, + execution::context::TaskContext, + physical_plan::{ + expressions::{Column, PhysicalSortExpr}, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + Statistics, + }, + scalar::ScalarValue, +}; +use observability_deps::tracing::trace; +use schema::sort::SortKey; +use std::{collections::HashMap, fmt, sync::Arc}; + +/// Implements the DataFusion physical plan interface for [`RecordBatch`]es with automatic projection and NULL-column creation. +/// +/// +/// [`RecordBatch`]: arrow::record_batch::RecordBatch +#[derive(Debug)] +pub(crate) struct RecordBatchesExec { + /// Chunks contained in this exec node. + chunks: Vec>, + + /// Overall schema. 
+ schema: SchemaRef, + + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + + /// Statistics over all batches. + statistics: Statistics, + + /// Sort key that was passed to [`chunks_to_physical_nodes`]. + /// + /// This is NOT used to set the output ordering. It is only here to recover this information later. + /// + /// + /// [`chunks_to_physical_nodes`]: super::physical::chunks_to_physical_nodes + output_sort_key_memo: Option, + + /// Output ordering. + output_ordering: Option>, +} + +impl RecordBatchesExec { + pub fn new( + chunks: impl IntoIterator>, + schema: SchemaRef, + output_sort_key_memo: Option, + ) -> Self { + let chunks: Vec<_> = chunks.into_iter().collect(); + let statistics = build_statistics_for_chunks(&chunks, Arc::clone(&schema)); + + let chunk_order_field = schema.field_with_name(CHUNK_ORDER_COLUMN_NAME).ok(); + let output_ordering = if chunk_order_field.is_some() { + Some(vec![ + // every chunk gets its own partition, so we can claim that the output is ordered + PhysicalSortExpr { + expr: Arc::new( + Column::new_with_schema(CHUNK_ORDER_COLUMN_NAME, &schema) + .expect("just checked presence of chunk order col"), + ), + options: Default::default(), + }, + ]) + } else { + None + }; + + Self { + chunks, + schema, + statistics, + output_sort_key_memo, + output_ordering, + metrics: ExecutionPlanMetricsSet::new(), + } + } + + /// Chunks that make up this node. + pub fn chunks(&self) -> impl Iterator> { + self.chunks.iter() + } + + /// Sort key that was passed to [`chunks_to_physical_nodes`]. + /// + /// This is NOT used to set the output ordering. It is only here to recover this information later. 
+ /// + /// + /// [`chunks_to_physical_nodes`]: super::physical::chunks_to_physical_nodes + pub fn output_sort_key_memo(&self) -> Option<&SortKey> { + self.output_sort_key_memo.as_ref() + } +} + +impl ExecutionPlan for RecordBatchesExec { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(self.chunks.len()) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.output_ordering.as_deref() + } + + fn children(&self) -> Vec> { + // no inputs + vec![] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion::error::Result> { + assert!(children.is_empty(), "no children expected in iox plan"); + + Ok(self) + } + + fn execute( + &self, + partition: usize, + _context: Arc, + ) -> datafusion::error::Result { + trace!(partition, "Start RecordBatchesExec::execute"); + + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + + let schema = self.schema(); + + let chunk = &self.chunks[partition]; + + let stream = match chunk.data() { + crate::QueryChunkData::RecordBatches(stream) => stream, + crate::QueryChunkData::Parquet(_) => { + return Err(DataFusionError::Execution(String::from( + "chunk must contain record batches", + ))); + } + }; + let virtual_columns = HashMap::from([( + CHUNK_ORDER_COLUMN_NAME, + ScalarValue::from(chunk.order().get()), + )]); + let adapter = Box::pin( + SchemaAdapterStream::try_new(stream, schema, &virtual_columns, baseline_metrics) + .map_err(|e| DataFusionError::External(Box::new(e)))?, + ); + + trace!(partition, "End RecordBatchesExec::execute"); + Ok(adapter) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + Ok(self.statistics.clone()) + } +} + +impl DisplayAs for RecordBatchesExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + 
match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "RecordBatchesExec: chunks={}", self.chunks.len(),)?; + if !self.schema.fields().is_empty() { + write!(f, ", projection={}", ProjectSchemaDisplay(&self.schema))?; + } + Ok(()) + } + } + } +} diff --git a/iox_query/src/pruning.rs b/iox_query/src/pruning.rs new file mode 100644 index 0000000..50f44f1 --- /dev/null +++ b/iox_query/src/pruning.rs @@ -0,0 +1,689 @@ +//! Implementation of statistics based pruning + +use crate::QueryChunk; +use arrow::{ + array::{ArrayRef, BooleanArray, UInt64Array}, + datatypes::{DataType, SchemaRef}, +}; +use datafusion::{ + physical_expr::execution_props::ExecutionProps, + physical_optimizer::pruning::PruningStatistics, + physical_plan::{ColumnStatistics, Statistics}, + prelude::{col, Column, Expr}, + scalar::ScalarValue, +}; +use datafusion_util::{create_pruning_predicate, lit_timestamptz_nano}; +use observability_deps::tracing::{debug, trace, warn}; +use query_functions::group_by::Aggregate; +use schema::{Schema, TIME_COLUMN_NAME}; +use std::collections::HashSet; +use std::sync::Arc; + +/// Reason why a chunk could not be pruned. +/// +/// Also see [`PruningObserver::could_not_prune`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum NotPrunedReason { + /// No expression on predicate + NoExpressionOnPredicate, + + /// Can not create pruning predicate + CanNotCreatePruningPredicate, + + /// DataFusion pruning failed + DataFusionPruningFailed, +} + +impl NotPrunedReason { + /// Human-readable string representation. 
+ pub fn name(&self) -> &'static str { + match self { + Self::NoExpressionOnPredicate => "No expression on predicate", + Self::CanNotCreatePruningPredicate => "Can not create pruning predicate", + Self::DataFusionPruningFailed => "DataFusion pruning failed", + } + } +} + +impl std::fmt::Display for NotPrunedReason { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.name()) + } +} + +/// Something that cares to be notified when pruning of chunks occurs +pub trait PruningObserver { + /// Called when the specified chunk was pruned + fn was_pruned(&self, _chunk: &dyn QueryChunk) {} + + /// Called when a chunk was not pruned. + fn was_not_pruned(&self, _chunk: &dyn QueryChunk) {} + + /// Called when no pruning can happen at all for some reason. + /// + /// Since pruning is optional and _only_ improves performance but its lack does not affect correctness, this will + /// NOT lead to a query error. + /// + /// In this case, statistical pruning will not happen and neither [`was_pruned`](Self::was_pruned) nor + /// [`was_not_pruned`](Self::was_not_pruned) will be called. + fn could_not_prune(&self, _reason: NotPrunedReason, _chunk: &dyn QueryChunk) {} +} + +/// Given a Vec of prunable items, returns a possibly smaller set +/// filtering those where the predicate can be proven to evaluate to +/// `false` for every single row. 
+pub fn prune_chunks( + table_schema: &Schema, + chunks: &[Arc], + filters: &[Expr], +) -> Result, NotPrunedReason> { + let num_chunks = chunks.len(); + debug!(num_chunks, ?filters, "Pruning chunks"); + let summaries: Vec<_> = chunks + .iter() + .map(|c| (c.stats(), c.schema().as_arrow())) + .collect(); + + let filter_expr = match filters.iter().cloned().reduce(|a, b| a.and(b)) { + Some(expr) => expr, + None => { + debug!("No expression on predicate"); + return Err(NotPrunedReason::NoExpressionOnPredicate); + } + }; + + prune_summaries(table_schema, &summaries, &filter_expr) +} + +/// Given a `Vec` of pruning summaries, return a `Vec` where `false` indicates that the +/// predicate can be proven to evaluate to `false` for every single row. +pub fn prune_summaries( + table_schema: &Schema, + summaries: &[(Arc, SchemaRef)], + filter_expr: &Expr, +) -> Result, NotPrunedReason> { + trace!(%filter_expr, "Filter_expr of pruning chunks"); + + // no information about the queries here + let props = ExecutionProps::new(); + let pruning_predicate = + match create_pruning_predicate(&props, filter_expr, &table_schema.as_arrow()) { + Ok(p) => p, + Err(e) => { + warn!(%e, ?filter_expr, "Can not create pruning predicate"); + return Err(NotPrunedReason::CanNotCreatePruningPredicate); + } + }; + + let statistics = ChunkPruningStatistics { + table_schema, + summaries, + }; + + let results = match pruning_predicate.prune(&statistics) { + Ok(results) => results, + Err(e) => { + warn!(%e, ?filter_expr, "DataFusion pruning failed"); + return Err(NotPrunedReason::DataFusionPruningFailed); + } + }; + Ok(results) +} + +/// Wraps a collection of [`QueryChunk`] and implements the [`PruningStatistics`] +/// interface required for pruning +struct ChunkPruningStatistics<'a> { + table_schema: &'a Schema, + summaries: &'a [(Arc, SchemaRef)], +} + +impl<'a> ChunkPruningStatistics<'a> { + /// Returns the [`DataType`] for `column` + fn column_type(&self, column: &Column) -> Option<&DataType> { + let 
index = self.table_schema.find_index_of(&column.name)?; + Some(self.table_schema.field(index).1.data_type()) + } + + /// Returns an iterator that for each chunk returns the [`Statistics`] + /// for the provided `column` if any + fn column_summaries<'b: 'a, 'c: 'a>( + &'c self, + column: &'b Column, + ) -> impl Iterator> + 'a { + self.summaries.iter().map(|(stats, schema)| { + let idx = schema.index_of(&column.name).ok()?; + Some(&stats.column_statistics[idx]) + }) + } +} + +impl<'a> PruningStatistics for ChunkPruningStatistics<'a> { + fn min_values(&self, column: &Column) -> Option { + let data_type = self.column_type(column)?; + let summaries = self.column_summaries(column); + collect_pruning_stats(data_type, summaries, Aggregate::Min) + } + + fn max_values(&self, column: &Column) -> Option { + let data_type = self.column_type(column)?; + let summaries = self.column_summaries(column); + collect_pruning_stats(data_type, summaries, Aggregate::Max) + } + + fn num_containers(&self) -> usize { + self.summaries.len() + } + + fn null_counts(&self, column: &Column) -> Option { + let null_counts = self + .column_summaries(column) + .map(|stats| stats.and_then(|stats| stats.null_count.get_value())) + .map(|x| x.map(|x| *x as u64)); + + Some(Arc::new(UInt64Array::from_iter(null_counts))) + } + + fn contained( + &self, + _column: &datafusion::common::Column, + _values: &HashSet, + ) -> Option { + None + } +} + +/// Collects an [`ArrayRef`] containing the aggregate statistic corresponding to +/// `aggregate` for each of the provided [`Statistics`] +fn collect_pruning_stats<'a>( + data_type: &DataType, + statistics: impl Iterator>, + aggregate: Aggregate, +) -> Option { + let null = ScalarValue::try_from(data_type).ok()?; + + ScalarValue::iter_to_array(statistics.map(|stats| { + stats + .and_then(|stats| get_aggregate(stats, aggregate).cloned()) + .unwrap_or_else(|| null.clone()) + })) + .ok() +} + +/// Returns the aggregate statistic corresponding to `aggregate` from `stats` 
+fn get_aggregate(stats: &ColumnStatistics, aggregate: Aggregate) -> Option<&ScalarValue> { + match aggregate { + Aggregate::Min => stats.min_value.get_value(), + Aggregate::Max => stats.max_value.get_value(), + _ => None, + } +} + +/// Retention time expression, "time > retention_time". +pub fn retention_expr(retention_time: i64) -> Expr { + col(TIME_COLUMN_NAME).gt(lit_timestamptz_nano(retention_time)) +} + +#[cfg(test)] +mod test { + use std::{ops::Not, sync::Arc}; + + use datafusion::prelude::{col, lit}; + use datafusion_util::lit_dict; + use schema::merge::SchemaMerger; + + use crate::{test::TestChunk, QueryChunk}; + + use super::*; + + #[test] + fn test_empty() { + test_helpers::maybe_start_logging(); + let c1 = Arc::new(TestChunk::new("chunk1")); + + let result = prune_chunks(&c1.schema().clone(), &[c1], &[]); + + assert_eq!(result, Err(NotPrunedReason::NoExpressionOnPredicate)); + } + + #[test] + fn test_pruned_f64() { + test_helpers::maybe_start_logging(); + // column1 > 100.0 where + // c1: [0.0, 10.0] --> pruned + let c1 = Arc::new(TestChunk::new("chunk1").with_f64_field_column_with_stats( + "column1", + Some(0.0), + Some(10.0), + )); + + let filters = vec![col("column1").gt(lit(100.0f64))]; + + let result = prune_chunks(&c1.schema().clone(), &[c1], &filters); + assert_eq!(result.expect("pruning succeeds"), vec![false]); + } + + #[test] + fn test_pruned_i64() { + test_helpers::maybe_start_logging(); + // column1 > 100 where + // c1: [0, 10] --> pruned + + let c1 = Arc::new(TestChunk::new("chunk1").with_i64_field_column_with_stats( + "column1", + Some(0), + Some(10), + )); + + let filters = vec![col("column1").gt(lit(100i64))]; + + let result = prune_chunks(&c1.schema().clone(), &[c1], &filters); + + assert_eq!(result.expect("pruning succeeds"), vec![false]); + } + + #[test] + fn test_pruned_u64() { + test_helpers::maybe_start_logging(); + // column1 > 100 where + // c1: [0, 10] --> pruned + + let c1 = 
Arc::new(TestChunk::new("chunk1").with_u64_field_column_with_stats( + "column1", + Some(0), + Some(10), + )); + + let filters = vec![col("column1").gt(lit(100u64))]; + + let result = prune_chunks(&c1.schema().clone(), &[c1], &filters); + assert_eq!(result.expect("pruning succeeds"), vec![false]); + } + + #[test] + fn test_pruned_bool() { + test_helpers::maybe_start_logging(); + // column1 where + // c1: [false, false] --> pruned + let c1 = Arc::new(TestChunk::new("chunk1").with_bool_field_column_with_stats( + "column1", + Some(false), + Some(false), + )); + + let filters = vec![col("column1")]; + + let result = prune_chunks(&c1.schema().clone(), &[c1], &filters); + assert_eq!(result.expect("pruning succeeds"), vec![false; 1]); + } + + #[test] + fn test_pruned_string() { + test_helpers::maybe_start_logging(); + // column1 > "z" where + // c1: ["a", "q"] --> pruned + + let c1 = Arc::new( + TestChunk::new("chunk1").with_string_field_column_with_stats( + "column1", + Some("a"), + Some("q"), + ), + ); + + let filters = vec![col("column1").gt(lit("z"))]; + + let result = prune_chunks(&c1.schema().clone(), &[c1], &filters); + assert_eq!(result.expect("pruning succeeds"), vec![false]); + } + + #[test] + fn test_not_pruned_f64() { + test_helpers::maybe_start_logging(); + // column1 < 100.0 where + // c1: [0.0, 10.0] --> not pruned + let c1 = Arc::new(TestChunk::new("chunk1").with_f64_field_column_with_stats( + "column1", + Some(0.0), + Some(10.0), + )); + + let filters = vec![col("column1").lt(lit(100.0f64))]; + + let result = prune_chunks(&c1.schema().clone(), &[c1], &filters); + assert_eq!(result.expect("pruning succeeds"), vec![true]); + } + + #[test] + fn test_not_pruned_i64() { + test_helpers::maybe_start_logging(); + // column1 < 100 where + // c1: [0, 10] --> not pruned + + let c1 = Arc::new(TestChunk::new("chunk1").with_i64_field_column_with_stats( + "column1", + Some(0), + Some(10), + )); + + let filters = vec![col("column1").lt(lit(100i64))]; + + let result = 
prune_chunks(&c1.schema().clone(), &[c1], &filters); + assert_eq!(result.expect("pruning succeeds"), vec![true]); + } + + #[test] + fn test_not_pruned_u64() { + test_helpers::maybe_start_logging(); + // column1 < 100 where + // c1: [0, 10] --> not pruned + + let c1 = Arc::new(TestChunk::new("chunk1").with_u64_field_column_with_stats( + "column1", + Some(0), + Some(10), + )); + + let filters = vec![col("column1").lt(lit(100u64))]; + + let result = prune_chunks(&c1.schema().clone(), &[c1], &filters); + assert_eq!(result.expect("pruning succeeds"), vec![true]); + } + + #[test] + fn test_not_pruned_bool() { + test_helpers::maybe_start_logging(); + // column1 + // c1: [false, true] --> not pruned + + let c1 = Arc::new(TestChunk::new("chunk1").with_bool_field_column_with_stats( + "column1", + Some(false), + Some(true), + )); + + let filters = vec![col("column1")]; + + let result = prune_chunks(&c1.schema().clone(), &[c1], &filters); + assert_eq!(result.expect("pruning succeeds"), vec![true]); + } + + #[test] + fn test_not_pruned_string() { + test_helpers::maybe_start_logging(); + // column1 < "z" where + // c1: ["a", "q"] --> not pruned + + let c1 = Arc::new( + TestChunk::new("chunk1").with_string_field_column_with_stats( + "column1", + Some("a"), + Some("q"), + ), + ); + + let filters = vec![col("column1").lt(lit("z"))]; + + let result = prune_chunks(&c1.schema().clone(), &[c1], &filters); + assert_eq!(result.expect("pruning succeeds"), vec![true]); + } + + fn merge_schema(chunks: &[Arc]) -> Schema { + let mut merger = SchemaMerger::new(); + for chunk in chunks { + merger = merger.merge(chunk.schema()).unwrap(); + } + merger.build() + } + + #[test] + fn test_pruned_null() { + test_helpers::maybe_start_logging(); + // column1 > 100 where + // c1: [Null, 10] --> pruned + // c2: [0, Null] --> not pruned + // c3: [Null, Null] --> not pruned (min/max are not known in chunk 3) + // c4: Null --> not pruned (no statistics at all) + + let c1 = 
Arc::new(TestChunk::new("chunk1").with_i64_field_column_with_stats( + "column1", + None, + Some(10), + )) as Arc; + + let c2 = Arc::new(TestChunk::new("chunk2").with_i64_field_column_with_stats( + "column1", + Some(0), + None, + )) as Arc; + + let c3 = Arc::new( + TestChunk::new("chunk3").with_i64_field_column_with_stats("column1", None, None), + ) as Arc; + + let c4 = Arc::new(TestChunk::new("chunk4").with_i64_field_column("column1")) + as Arc; + + let filters = vec![col("column1").gt(lit(100i64))]; + + let chunks = vec![c1, c2, c3, c4]; + let schema = merge_schema(&chunks); + + let result = prune_chunks(&schema, &chunks, &filters); + + assert_eq!( + result.expect("pruning succeeds"), + vec![false, true, true, true] + ); + } + + #[test] + fn test_pruned_multi_chunk() { + test_helpers::maybe_start_logging(); + // column1 > 100 where + // c1: [0, 10] --> pruned + // c2: [0, 1000] --> not pruned + // c3: [10, 20] --> pruned + // c4: [None, None] --> not pruned + // c5: [10, None] --> not pruned + // c6: [None, 10] --> pruned + + let c1 = Arc::new(TestChunk::new("chunk1").with_i64_field_column_with_stats( + "column1", + Some(0), + Some(10), + )) as Arc; + + let c2 = Arc::new(TestChunk::new("chunk2").with_i64_field_column_with_stats( + "column1", + Some(0), + Some(1000), + )) as Arc; + + let c3 = Arc::new(TestChunk::new("chunk3").with_i64_field_column_with_stats( + "column1", + Some(10), + Some(20), + )) as Arc; + + let c4 = Arc::new( + TestChunk::new("chunk4").with_i64_field_column_with_stats("column1", None, None), + ) as Arc; + + let c5 = Arc::new(TestChunk::new("chunk5").with_i64_field_column_with_stats( + "column1", + Some(10), + None, + )) as Arc; + + let c6 = Arc::new(TestChunk::new("chunk6").with_i64_field_column_with_stats( + "column1", + None, + Some(20), + )) as Arc; + + let filters = vec![col("column1").gt(lit(100i64))]; + + let chunks = vec![c1, c2, c3, c4, c5, c6]; + let schema = merge_schema(&chunks); + + let result = prune_chunks(&schema, &chunks, 
&filters); + + assert_eq!( + result.expect("pruning succeeds"), + vec![false, true, false, true, true, false] + ); + } + + #[test] + fn test_pruned_different_schema() { + test_helpers::maybe_start_logging(); + // column1 > 100 where + // c1: column1 [0, 100], column2 [0, 4] --> pruned (in range, column2 ignored) + // c2: column1 [0, 1000], column2 [0, 4] --> not pruned (in range, column2 ignored) + // c3: None, column2 [0, 4] --> not pruned (no stats for column1) + let c1 = Arc::new( + TestChunk::new("chunk1") + .with_i64_field_column_with_stats("column1", Some(0), Some(100)) + .with_i64_field_column_with_stats("column2", Some(0), Some(4)), + ) as Arc; + + let c2 = Arc::new( + TestChunk::new("chunk2") + .with_i64_field_column_with_stats("column1", Some(0), Some(1000)) + .with_i64_field_column_with_stats("column2", Some(0), Some(4)), + ) as Arc; + + let c3 = Arc::new(TestChunk::new("chunk3").with_i64_field_column_with_stats( + "column2", + Some(0), + Some(4), + )) as Arc; + + let filters = vec![col("column1").gt(lit(100i64))]; + + let chunks = vec![c1, c2, c3]; + let schema = merge_schema(&chunks); + + let result = prune_chunks(&schema, &chunks, &filters); + + assert_eq!(result.expect("pruning succeeds"), vec![false, true, true]); + } + + #[test] + fn test_pruned_is_null() { + test_helpers::maybe_start_logging(); + // Verify that type of predicate is pruned if column1 is null + // (this is a common predicate type created by the INfluxRPC planner) + // (NOT column1 IS NULL) AND (column1 = 'bar') + + // No nulls, can't prune as it has values that are more and less than 'bar' + let c1 = Arc::new( + TestChunk::new("chunk1").with_tag_column_with_nulls_and_full_stats( + "column1", + Some("a"), + Some("z"), + 100, + None, + 0, + ), + ) as Arc; + + // Has no nulls, can prune it out based on statistics alone + let c2 = Arc::new( + TestChunk::new("chunk2").with_tag_column_with_nulls_and_full_stats( + "column1", + Some("a"), + Some("b"), + 100, + None, + 0, + ), + ) as Arc; + 
+ // Has nulls, can still can prune it out based on statistics alone + let c3 = Arc::new( + TestChunk::new("chunk3").with_tag_column_with_nulls_and_full_stats( + "column1", + Some("a"), + Some("b"), + 100, + None, + 1, // that one peksy null! + ), + ) as Arc; + + let filters = vec![col("column1") + .is_null() + .not() + .and(col("column1").eq(lit_dict("bar")))]; + + let chunks = vec![c1, c2, c3]; + let schema = merge_schema(&chunks); + + let result = prune_chunks(&schema, &chunks, &filters); + + assert_eq!(result.expect("pruning succeeds"), vec![true, false, false]); + } + + #[test] + fn test_pruned_multi_column() { + test_helpers::maybe_start_logging(); + // column1 > 100 AND column2 < 5 where + // c1: column1 [0, 1000], column2 [0, 4] --> not pruned (both in range) + // c2: column1 [0, 10], column2 [0, 4] --> pruned (column1 and column2 out of range) + // c3: column1 [0, 10], column2 [5, 10] --> pruned (column1 out of range, column2 in of range) + // c4: column1 [1000, 2000], column2 [0, 4] --> not pruned (column1 in range, column2 in range) + // c5: column1 [0, 10], column2 Null --> pruned (column1 out of range, but column2 has no stats) + // c6: column1 Null, column2 [0, 4] --> not pruned (column1 has no stats, column2 out of range) + + let c1 = Arc::new( + TestChunk::new("chunk1") + .with_i64_field_column_with_stats("column1", Some(0), Some(1000)) + .with_i64_field_column_with_stats("column2", Some(0), Some(4)), + ) as Arc; + + let c2 = Arc::new( + TestChunk::new("chunk2") + .with_i64_field_column_with_stats("column1", Some(0), Some(10)) + .with_i64_field_column_with_stats("column2", Some(0), Some(4)), + ) as Arc; + + let c3 = Arc::new( + TestChunk::new("chunk3") + .with_i64_field_column_with_stats("column1", Some(0), Some(10)) + .with_i64_field_column_with_stats("column2", Some(5), Some(10)), + ) as Arc; + + let c4 = Arc::new( + TestChunk::new("chunk4") + .with_i64_field_column_with_stats("column1", Some(1000), Some(2000)) + 
.with_i64_field_column_with_stats("column2", Some(0), Some(4)), + ) as Arc; + + let c5 = Arc::new( + TestChunk::new("chunk5") + .with_i64_field_column_with_stats("column1", Some(0), Some(10)) + .with_i64_field_column("column2"), + ) as Arc; + + let c6 = Arc::new( + TestChunk::new("chunk6") + .with_i64_field_column("column1") + .with_i64_field_column_with_stats("column2", Some(0), Some(4)), + ) as Arc; + + let filters = vec![col("column1") + .gt(lit(100i64)) + .and(col("column2").lt(lit(5i64)))]; + + let chunks = vec![c1, c2, c3, c4, c5, c6]; + let schema = merge_schema(&chunks); + + let result = prune_chunks(&schema, &chunks, &filters); + + assert_eq!( + result.expect("Pruning succeeds"), + vec![true, false, false, true, false, true] + ); + } +} diff --git a/iox_query/src/query_log.rs b/iox_query/src/query_log.rs new file mode 100644 index 0000000..e6ae929 --- /dev/null +++ b/iox_query/src/query_log.rs @@ -0,0 +1,704 @@ +//! Ring buffer of queries that have been run with some brief information + +use data_types::NamespaceId; +use datafusion::physical_plan::ExecutionPlan; +use iox_time::{Time, TimeProvider}; +use observability_deps::tracing::{info, warn}; +use parking_lot::Mutex; +use std::{ + collections::VecDeque, + fmt::Debug, + sync::{ + atomic::{self, AtomicBool, AtomicI64, AtomicUsize, Ordering}, + Arc, + }, + time::Duration, +}; +use trace::ctx::TraceId; +use uuid::Uuid; + +/// The query duration used for queries still running. +const UNCOMPLETED_DURATION: i64 = -1; + +/// Information about a single query that was executed +pub struct QueryLogEntry { + /// Unique ID. + pub id: Uuid, + + /// Namespace ID. + pub namespace_id: NamespaceId, + + /// Namespace name. 
+ pub namespace_name: Arc, + + /// The type of query + pub query_type: &'static str, + + /// The text of the query (SQL for sql queries, pbjson for storage rpc queries) + pub query_text: QueryText, + + /// The trace ID if any + pub trace_id: Option, + + /// Time at which the query was run + pub issue_time: Time, + + /// Duration it took to acquire a semaphore permit, relative to [`issue_time`](Self::issue_time). + permit_duration: AtomicDuration, + + /// Duration it took to plan the query, relative to [`issue_time`](Self::issue_time) + [`permit_duration`](Self::permit_duration). + plan_duration: AtomicDuration, + + /// Duration it took to execute the query, relative to [`issue_time`](Self::issue_time) + + /// [`permit_duration`](Self::permit_duration) + [`plan_duration`](Self::plan_duration). + execute_duration: AtomicDuration, + + /// Duration from [`issue_time`](Self::issue_time) til the query ended somehow. + end2end_duration: AtomicDuration, + + /// CPU duration spend for computation. + compute_duration: AtomicDuration, + + /// If the query completed successfully + success: AtomicBool, + + /// If the query is currently running (in any state). 
+ running: AtomicBool, +} + +impl Debug for QueryLogEntry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("QueryLogEntry") + .field("id", &self.id) + .field("namespace_id", &self.namespace_id) + .field("namespace_name", &self.namespace_name) + .field("query_type", &self.query_type) + .field("query_text", &self.query_text.to_string()) + .field("trace_id", &self.trace_id) + .field("issue_time", &self.issue_time) + .field("permit_duration", &self.permit_duration()) + .field("plan_duration", &self.plan_duration()) + .field("execute_duration", &self.execute_duration()) + .field("end2end_duration", &self.end2end_duration()) + .field("compute_duration", &self.compute_duration()) + .field("success", &self.success()) + .field("running", &self.running()) + .finish() + } +} + +impl QueryLogEntry { + /// Duration it took to acquire a semaphore permit, relative to [`issue_time`](Self::issue_time). + pub fn permit_duration(&self) -> Option { + self.permit_duration.get() + } + + /// Duration it took to plan the query, relative to [`issue_time`](Self::issue_time) + [`permit_duration`](Self::permit_duration). + pub fn plan_duration(&self) -> Option { + self.plan_duration.get() + } + + /// Duration it took to execute the query, relative to [`issue_time`](Self::issue_time) + + /// [`permit_duration`](Self::permit_duration) + [`plan_duration`](Self::plan_duration). + pub fn execute_duration(&self) -> Option { + self.execute_duration.get() + } + + /// Duration from [`issue_time`](Self::issue_time) til the query ended somehow. + pub fn end2end_duration(&self) -> Option { + self.end2end_duration.get() + } + + /// CPU duration spend for computation. + pub fn compute_duration(&self) -> Option { + self.compute_duration.get() + } + + /// Returns true if `set_completed` was called with `success=true` + pub fn success(&self) -> bool { + self.success.load(Ordering::SeqCst) + } + + /// If the query is currently running (in any state). 
+ pub fn running(&self) -> bool { + self.running.load(Ordering::SeqCst) + } + + /// Log entry. + pub fn log(&self, when: &'static str) { + info!( + when, + id=%self.id, + namespace_id=self.namespace_id.get(), + namespace_name=self.namespace_name.as_ref(), + query_type=self.query_type, + query_text=%self.query_text, + trace_id=self.trace_id.map(|id| format!("{:x}", id.get())), + issue_time=%self.issue_time, + plan_duration_secs=self.plan_duration().map(|d| d.as_secs_f64()), + permit_duration_secs=self.permit_duration().map(|d| d.as_secs_f64()), + execute_duration_secs=self.execute_duration().map(|d| d.as_secs_f64()), + end2end_duration_secs=self.end2end_duration().map(|d| d.as_secs_f64()), + compute_duration_secs=self.compute_duration().map(|d| d.as_secs_f64()), + success=self.success(), + running=self.running(), + "query", + ) + } +} + +/// Snapshot of the entries the [`QueryLog`]. +#[derive(Debug)] +pub struct QueryLogEntries { + /// Entries. + pub entries: VecDeque>, + + /// Maximum number of entries + pub max_size: usize, + + /// Number of evicted entries due to the "max size" constraint. + pub evicted: usize, +} + +/// Stores a fixed number `QueryExecutions` -- handles locking +/// internally so can be shared across multiple +pub struct QueryLog { + log: Mutex>>, + max_size: usize, + evicted: AtomicUsize, + time_provider: Arc, + id_gen: IDGen, +} + +impl QueryLog { + /// Create a new QueryLog that can hold at most `size` items. + /// When the `size+1` item is added, item `0` is evicted. 
+ pub fn new(max_size: usize, time_provider: Arc) -> Self { + Self::new_with_id_gen(max_size, time_provider, Box::new(Uuid::new_v4)) + } + + pub fn new_with_id_gen( + max_size: usize, + time_provider: Arc, + id_gen: IDGen, + ) -> Self { + Self { + log: Mutex::new(VecDeque::with_capacity(max_size)), + max_size, + evicted: AtomicUsize::new(0), + time_provider, + id_gen, + } + } + + pub fn push( + &self, + namespace_id: NamespaceId, + namespace_name: Arc, + query_type: &'static str, + query_text: QueryText, + trace_id: Option, + ) -> QueryCompletedToken { + let entry = Arc::new(QueryLogEntry { + id: (self.id_gen)(), + namespace_id, + namespace_name, + query_type, + query_text, + trace_id, + issue_time: self.time_provider.now(), + permit_duration: Default::default(), + plan_duration: Default::default(), + execute_duration: Default::default(), + end2end_duration: Default::default(), + compute_duration: Default::default(), + success: atomic::AtomicBool::new(false), + running: atomic::AtomicBool::new(true), + }); + entry.log("start"); + let token = QueryCompletedToken { + entry: Some(Arc::clone(&entry)), + time_provider: Arc::clone(&self.time_provider), + state: Default::default(), + }; + + if self.max_size == 0 { + return token; + } + + let mut log = self.log.lock(); + + // enforce limit + while log.len() > self.max_size { + log.pop_front(); + self.evicted.fetch_add(1, Ordering::SeqCst); + } + + log.push_back(Arc::clone(&entry)); + token + } + + pub fn entries(&self) -> QueryLogEntries { + let log = self.log.lock(); + QueryLogEntries { + entries: log.clone(), + max_size: self.max_size, + evicted: self.evicted.load(Ordering::SeqCst), + } + } +} + +impl Debug for QueryLog { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("QueryLog") + .field("log", &self.log) + .field("max_size", &self.max_size) + .field("evicted", &self.evicted) + .field("time_provider", &self.time_provider) + .field("id_gen", &"") + .finish() + } +} + +/// State of 
[`QueryCompletedToken`]. +/// +/// # Done +/// - The query has been received (and potentially authenticated) by the server. +/// +/// # To Do +/// - The concurrency-limiting semaphore has NOT yet issued a permit. +/// - The query is not planned. +/// - The query has not been executed. +#[derive(Debug, Clone, Copy, Default)] +pub struct StateReceived; + +/// State of [`QueryCompletedToken`]. +/// +/// # Done +/// - The query has been received (and potentially authenticated) by the server. +/// - The concurrency-limiting semaphore has issued a permit. +/// - The query was planned. +/// +/// # To Do +/// - The concurrency-limiting semaphore has NOT yet issued a permit. +/// - The query has not been executed. +#[derive(Debug)] +pub struct StatePlanned { + /// Physical execution plan. + plan: Arc, +} + +/// State of [`QueryCompletedToken`]. +/// +/// # Done +/// - The query has been received (and potentially authenticated) by the server. +/// - The concurrency-limiting semaphore has issued a permit. +/// +/// # To Do +/// - The query has not been executed. +#[derive(Debug)] +pub struct StatePermit { + /// Physical execution plan. + plan: Arc, +} + +/// A `QueryCompletedToken` is returned by `record_query` implementations of +/// a `QueryNamespace`. It is used to trigger side-effects (such as query timing) +/// on query completion. +#[derive(Debug)] +pub struct QueryCompletedToken { + /// Entry. + /// + /// This is optional so we can implement type state and [`Drop`] at the same time. + entry: Option>, + + /// Time provider + time_provider: Arc, + + /// Current state. + state: S, +} + +impl QueryCompletedToken { + /// Underlying entry. + pub fn entry(&self) -> &Arc { + self.entry.as_ref().expect("valid state") + } +} + +impl QueryCompletedToken { + /// Record that this query got planned. 
+ pub fn planned(mut self, plan: Arc) -> QueryCompletedToken { + let entry = self.entry.take().expect("valid state"); + + let now = self.time_provider.now(); + let origin = entry.issue_time; + entry.plan_duration.set_relative(origin, now); + + QueryCompletedToken { + entry: Some(entry), + time_provider: Arc::clone(&self.time_provider), + state: StatePlanned { plan }, + } + } +} + +impl QueryCompletedToken { + /// Record that this query got a semaphore permit. + pub fn permit(mut self) -> QueryCompletedToken { + let entry = self.entry.take().expect("valid state"); + + let now = self.time_provider.now(); + let origin = entry.issue_time + entry.plan_duration().expect("valid state"); + entry.permit_duration.set_relative(origin, now); + + QueryCompletedToken { + entry: Some(entry), + time_provider: Arc::clone(&self.time_provider), + state: StatePermit { + plan: Arc::clone(&self.state.plan), + }, + } + } +} + +impl QueryCompletedToken { + /// Record that this query completed successfully + pub fn success(self) { + let entry = self.entry.as_ref().expect("valid state"); + entry.success.store(true, Ordering::SeqCst); + + self.finish() + } + + /// Record that the query finished execution with an error. 
+ pub fn fail(self) { + self.finish() + } + + fn finish(&self) { + let entry = self.entry.as_ref().expect("valid state"); + + let now = self.time_provider.now(); + let origin = entry.issue_time + + entry.permit_duration().expect("valid state") + + entry.plan_duration().expect("valid state"); + entry.execute_duration.set_relative(origin, now); + + entry + .compute_duration + .set_absolute(collect_compute_duration(self.state.plan.as_ref())); + } +} + +impl Drop for QueryCompletedToken { + fn drop(&mut self) { + if let Some(entry) = self.entry.take() { + let now = self.time_provider.now(); + entry.end2end_duration.set_relative(entry.issue_time, now); + entry.running.store(false, Ordering::SeqCst); + + entry.log("end"); + } + } +} + +/// Boxed description of a query that knows how to render to a string +/// +/// This avoids storing potentially large strings +pub type QueryText = Box; + +/// Method that generated [`Uuid`]s. +pub type IDGen = Box Uuid + Send + Sync>; + +struct AtomicDuration(AtomicI64); + +impl AtomicDuration { + fn get(&self) -> Option { + match self.0.load(Ordering::Relaxed) { + UNCOMPLETED_DURATION => None, + d => Some(Duration::from_nanos(d as u64)), + } + } + + fn set_relative(&self, origin: Time, now: Time) { + match now.checked_duration_since(origin) { + Some(dur) => { + self.0.store(dur.as_nanos() as i64, Ordering::Relaxed); + } + None => { + warn!("Clock went backwards, not query duration") + } + } + } + + fn set_absolute(&self, d: Duration) { + self.0.store(d.as_nanos() as i64, Ordering::Relaxed); + } +} + +impl Default for AtomicDuration { + fn default() -> Self { + Self(AtomicI64::new(UNCOMPLETED_DURATION)) + } +} + +/// Collect compute duration from [`ExecutionPlan`]. 
+fn collect_compute_duration(plan: &dyn ExecutionPlan) -> Duration { + let mut total = Duration::ZERO; + + if let Some(metrics) = plan.metrics() { + if let Some(nanos) = metrics.elapsed_compute() { + total += Duration::from_nanos(nanos as u64); + } + } + + for child in plan.children() { + total += collect_compute_duration(child.as_ref()); + } + + total +} + +#[cfg(test)] +mod test_super { + use datafusion::error::DataFusionError; + use std::sync::atomic::AtomicU64; + + use datafusion::physical_plan::{ + metrics::{MetricValue, MetricsSet}, + DisplayAs, Metric, + }; + use iox_time::MockProvider; + use test_helpers::tracing::TracingCapture; + + use super::*; + + #[test] + fn test_token_end2end_success() { + let capture = TracingCapture::new(); + + let Test { + time_provider, + token, + entry, + } = Test::default(); + + assert!(!entry.success()); + assert!(entry.running()); + assert_eq!(entry.permit_duration(), None,); + assert_eq!(entry.plan_duration(), None,); + assert_eq!(entry.execute_duration(), None,); + assert_eq!(entry.end2end_duration(), None,); + assert_eq!(entry.compute_duration(), None,); + + time_provider.inc(Duration::from_millis(1)); + let token = token.planned(plan()); + + assert!(!entry.success()); + assert!(entry.running()); + assert_eq!(entry.plan_duration(), Some(Duration::from_millis(1)),); + assert_eq!(entry.permit_duration(), None,); + assert_eq!(entry.execute_duration(), None,); + assert_eq!(entry.end2end_duration(), None,); + assert_eq!(entry.compute_duration(), None,); + + time_provider.inc(Duration::from_millis(10)); + let token = token.permit(); + + assert!(!entry.success()); + assert!(entry.running()); + assert_eq!(entry.plan_duration(), Some(Duration::from_millis(1)),); + assert_eq!(entry.permit_duration(), Some(Duration::from_millis(10)),); + assert_eq!(entry.execute_duration(), None,); + assert_eq!(entry.end2end_duration(), None,); + assert_eq!(entry.compute_duration(), None,); + + time_provider.inc(Duration::from_millis(100)); + 
token.success(); + + assert!(entry.success()); + assert!(!entry.running()); + assert_eq!(entry.plan_duration(), Some(Duration::from_millis(1)),); + assert_eq!(entry.permit_duration(), Some(Duration::from_millis(10)),); + assert_eq!(entry.execute_duration(), Some(Duration::from_millis(100)),); + assert_eq!(entry.end2end_duration(), Some(Duration::from_millis(111)),); + assert_eq!(entry.compute_duration(), Some(Duration::from_millis(1_337)),); + + assert_eq!( + capture.to_string().trim(), + [ + r#"level = INFO; message = query; when = "start"; id = 00000000-0000-0000-0000-000000000001; namespace_id = 1; namespace_name = "ns"; query_type = "sql"; query_text = SELECT 1; issue_time = 1970-01-01T00:00:00.100+00:00; success = false; running = true;"#, + r#"level = INFO; message = query; when = "end"; id = 00000000-0000-0000-0000-000000000001; namespace_id = 1; namespace_name = "ns"; query_type = "sql"; query_text = SELECT 1; issue_time = 1970-01-01T00:00:00.100+00:00; plan_duration_secs = 0.001; permit_duration_secs = 0.01; execute_duration_secs = 0.1; end2end_duration_secs = 0.111; compute_duration_secs = 1.337; success = true; running = false;"#, + ].join(" \n") + ); + } + + #[test] + fn test_token_execution_fail() { + let capture = TracingCapture::new(); + + let Test { + time_provider, + token, + entry, + } = Test::default(); + + time_provider.inc(Duration::from_millis(1)); + let token = token.planned(plan()); + time_provider.inc(Duration::from_millis(10)); + let token = token.permit(); + time_provider.inc(Duration::from_millis(100)); + token.fail(); + + assert!(!entry.success()); + assert!(!entry.running()); + assert_eq!(entry.plan_duration(), Some(Duration::from_millis(1)),); + assert_eq!(entry.permit_duration(), Some(Duration::from_millis(10)),); + assert_eq!(entry.execute_duration(), Some(Duration::from_millis(100)),); + assert_eq!(entry.end2end_duration(), Some(Duration::from_millis(111)),); + assert_eq!(entry.compute_duration(), 
Some(Duration::from_millis(1_337)),); + + assert_eq!( + capture.to_string().trim(), + [ + r#"level = INFO; message = query; when = "start"; id = 00000000-0000-0000-0000-000000000001; namespace_id = 1; namespace_name = "ns"; query_type = "sql"; query_text = SELECT 1; issue_time = 1970-01-01T00:00:00.100+00:00; success = false; running = true;"#, + r#"level = INFO; message = query; when = "end"; id = 00000000-0000-0000-0000-000000000001; namespace_id = 1; namespace_name = "ns"; query_type = "sql"; query_text = SELECT 1; issue_time = 1970-01-01T00:00:00.100+00:00; plan_duration_secs = 0.001; permit_duration_secs = 0.01; execute_duration_secs = 0.1; end2end_duration_secs = 0.111; compute_duration_secs = 1.337; success = false; running = false;"#, + ].join(" \n") + ); + } + + #[test] + fn test_token_drop_before_acquire() { + let capture = TracingCapture::new(); + + let Test { + time_provider, + token, + entry, + } = Test::default(); + + time_provider.inc(Duration::from_millis(100)); + drop(token); + + assert!(!entry.success()); + assert!(!entry.running()); + assert_eq!(entry.permit_duration(), None,); + assert_eq!(entry.plan_duration(), None,); + assert_eq!(entry.execute_duration(), None,); + assert_eq!(entry.end2end_duration(), Some(Duration::from_millis(100)),); + assert_eq!(entry.compute_duration(), None,); + + assert_eq!( + capture.to_string().trim(), + [ + r#"level = INFO; message = query; when = "start"; id = 00000000-0000-0000-0000-000000000001; namespace_id = 1; namespace_name = "ns"; query_type = "sql"; query_text = SELECT 1; issue_time = 1970-01-01T00:00:00.100+00:00; success = false; running = true;"#, + r#"level = INFO; message = query; when = "end"; id = 00000000-0000-0000-0000-000000000001; namespace_id = 1; namespace_name = "ns"; query_type = "sql"; query_text = SELECT 1; issue_time = 1970-01-01T00:00:00.100+00:00; end2end_duration_secs = 0.1; success = false; running = false;"#, + ].join(" \n") + ); + } + + struct Test { + time_provider: Arc, + token: 
QueryCompletedToken, + entry: Arc, + } + + impl Default for Test { + fn default() -> Self { + let time_provider = + Arc::new(MockProvider::new(Time::from_timestamp_millis(100).unwrap())); + let id_counter = AtomicU64::new(1); + let log = QueryLog::new_with_id_gen( + 1_000, + Arc::clone(&time_provider) as _, + Box::new(move || Uuid::from_u128(id_counter.fetch_add(1, Ordering::SeqCst) as _)), + ); + + let token = log.push( + NamespaceId::new(1), + Arc::from("ns"), + "sql", + Box::new("SELECT 1"), + None, + ); + + let entry = Arc::clone(token.entry()); + + Self { + time_provider, + token, + entry, + } + } + } + + fn plan() -> Arc { + Arc::new(TestExec) + } + + #[derive(Debug)] + struct TestExec; + + impl DisplayAs for TestExec { + fn fmt_as( + &self, + _t: datafusion::physical_plan::DisplayFormatType, + _f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + unimplemented!() + } + } + + impl ExecutionPlan for TestExec { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> arrow::datatypes::SchemaRef { + unimplemented!() + } + + fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning { + unimplemented!() + } + + fn output_ordering(&self) -> Option<&[datafusion::physical_expr::PhysicalSortExpr]> { + unimplemented!() + } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> datafusion::error::Result> { + unimplemented!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> datafusion::error::Result + { + unimplemented!() + } + + fn statistics(&self) -> Result { + unimplemented!() + } + + fn metrics(&self) -> Option { + let mut metrics = MetricsSet::default(); + + let t = datafusion::physical_plan::metrics::Time::default(); + t.add_duration(Duration::from_millis(1_337)); + metrics.push(Arc::new(Metric::new(MetricValue::ElapsedCompute(t), None))); + + Some(metrics) + } + } +} diff --git a/iox_query/src/statistics.rs b/iox_query/src/statistics.rs new 
file mode 100644 index 0000000..3fc4d54 --- /dev/null +++ b/iox_query/src/statistics.rs @@ -0,0 +1,1447 @@ +//! Code to translate IOx statistics to DataFusion statistics + +use std::collections::{HashMap, VecDeque}; +use std::sync::Arc; + +use arrow::compute::rank; +use arrow::datatypes::{Schema, SchemaRef}; +use datafusion::common::stats::Precision; +use datafusion::datasource::physical_plan::ParquetExec; +use datafusion::error::DataFusionError; +use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion::physical_plan::empty::EmptyExec; +use datafusion::physical_plan::expressions::Column; +use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; +use datafusion::physical_plan::projection::ProjectionExec; +use datafusion::physical_plan::repartition::RepartitionExec; +use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion::physical_plan::union::UnionExec; +use datafusion::physical_plan::{visit_execution_plan, ExecutionPlan, ExecutionPlanVisitor}; +use datafusion::{ + physical_plan::{ColumnStatistics, Statistics as DFStatistics}, + scalar::ScalarValue, +}; +use observability_deps::tracing::trace; + +use crate::provider::{DeduplicateExec, RecordBatchesExec}; +use crate::{QueryChunk, CHUNK_ORDER_COLUMN_NAME}; + +/// Aggregates DataFusion [statistics](DFStatistics). +#[derive(Debug)] +pub struct DFStatsAggregator<'a> { + num_rows: Precision, + total_byte_size: Precision, + column_statistics: Vec, + // Maps column name to index in column_statistics for all columns we are + // aggregating + col_idx_map: HashMap<&'a str, usize>, +} + +impl<'a> DFStatsAggregator<'a> { + /// Creates new aggregator the the given schema. 
+ /// + /// This will start with: + /// + /// - 0 rows + /// - 0 bytes + /// - for each column: + /// - 0 null values + /// - unknown min value + /// - unknown max value + /// - exact representation + pub fn new(schema: &'a Schema) -> Self { + let col_idx_map = schema + .fields() + .iter() + .enumerate() + .map(|(idx, f)| (f.name().as_str(), idx)) + .collect::>(); + + Self { + num_rows: Precision::Exact(0), + total_byte_size: Precision::Exact(0), + column_statistics: (0..col_idx_map.len()) + .map(|_| DFStatsAggregatorCol { + null_count: Precision::Exact(0), + max_value: None, + min_value: None, + }) + .collect(), + + col_idx_map, + } + } + + /// Update given base statistics with the given schema. + /// + /// This only updates columns that were present when the aggregator was created. Column reordering is allowed. + /// + /// Updates are meant to be "additive", i.e. they only add data/rows. There is NOT way to remove/substract data from + /// the accumulator. + /// + /// # Panics + /// Panics when the number of columns in the statistics and the schema are different. 
+ pub fn update(&mut self, update_stats: &DFStatistics, update_schema: &Schema) { + // decompose structs so we don't forget new fields + let DFStatistics { + num_rows: update_num_rows, + total_byte_size: update_total_byte_size, + column_statistics: update_column_statistics, + } = update_stats; + + self.num_rows = self.num_rows.add(update_num_rows); + self.total_byte_size = self.total_byte_size.add(update_total_byte_size); + + assert_eq!(self.column_statistics.len(), self.col_idx_map.len()); + assert_eq!( + update_column_statistics.len(), + update_schema.fields().len(), + "stats ({}) and schema ({}) have different column count", + update_column_statistics.len(), + update_schema.fields().len(), + ); + + let mut used_cols = vec![false; self.col_idx_map.len()]; + + for (update_field, update_col) in update_schema + .fields() + .iter() + .zip(update_column_statistics.iter()) + { + // Skip if not aggregating statitics for this field + let Some(idx) = self.col_idx_map.get(update_field.name().as_str()) else { + continue; + }; + let base_col = &mut self.column_statistics[*idx]; + used_cols[*idx] = true; + + // decompose structs so we don't forget new fields + let DFStatsAggregatorCol { + null_count: base_null_count, + max_value: base_max_value, + min_value: base_min_value, + } = base_col; + let ColumnStatistics { + null_count: update_null_count, + max_value: update_max_value, + min_value: update_min_value, + distinct_count: _update_distinct_count, + } = update_col; + + *base_null_count = base_null_count.add(update_null_count); + + *base_max_value = Some( + base_max_value + .take() + .map(|base_max_value| base_max_value.max(update_max_value)) + .unwrap_or(update_max_value.clone()), + ); + + *base_min_value = Some( + base_min_value + .take() + .map(|base_min_value| base_min_value.min(update_min_value)) + .unwrap_or(update_min_value.clone()), + ); + } + + // for unused cols, we need to assume all-NULL and hence invalidate the null counters + for (used, base_col) in 
used_cols.into_iter().zip(&mut self.column_statistics) { + if !used { + base_col.null_count = Precision::Absent; + } + } + } + + /// Build aggregated statistics. + pub fn build(self) -> DFStatistics { + DFStatistics { + num_rows: self.num_rows, + total_byte_size: self.total_byte_size, + column_statistics: self + .column_statistics + .into_iter() + .map(|col| ColumnStatistics { + null_count: col.null_count, + max_value: col.max_value.unwrap_or(Precision::Absent), + min_value: col.min_value.unwrap_or(Precision::Absent), + distinct_count: Precision::Absent, + }) + .collect(), + } + } +} + +/// Similar to [`ColumnStatistics`] but uses `Option` to track min/max values so +/// we can differentiate between +/// +/// 1. "uninitialized" (`None`) +/// 1. "initialized" (`Some(Precision::Exact(...))`) +/// 2. "initialized but invalid" (`Some(Precision::Absent)`). +/// +/// It also does NOT contain a distinct count because we cannot aggregate these. +#[derive(Debug)] +struct DFStatsAggregatorCol { + null_count: Precision, + max_value: Option>, + min_value: Option>, +} + +/// build DF statitics for given chunks and a schema +pub fn build_statistics_for_chunks( + chunks: &[Arc], + schema: SchemaRef, +) -> DFStatistics { + let chunk_order_field = schema.field_with_name(CHUNK_ORDER_COLUMN_NAME).ok(); + let chunk_order_only_schema = chunk_order_field.map(|field| Schema::new(vec![field.clone()])); + + let chunks: Vec<_> = chunks.iter().collect(); + + let statistics = chunks + .iter() + .fold(DFStatsAggregator::new(&schema), |mut agg, chunk| { + agg.update(&chunk.stats(), chunk.schema().as_arrow().as_ref()); + + if let Some(schema) = chunk_order_only_schema.as_ref() { + let order = chunk.order().get(); + let order = ScalarValue::from(order); + + agg.update( + &DFStatistics { + num_rows: Precision::Exact(0), + total_byte_size: Precision::Exact(0), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(order.clone()), + min_value: 
Precision::Exact(order), + distinct_count: Precision::Exact(1), + }], + }, + schema, + ); + } + + agg + }) + .build(); + + statistics +} + +/// Traverse the execution plan and build statistics min max for the given column +pub fn compute_stats_column_min_max( + plan: &dyn ExecutionPlan, + column_name: &str, +) -> Result { + let mut visitor = StatisticsVisitor::new(column_name); + visit_execution_plan(plan, &mut visitor)?; + + // there must be only one statistics left in the stack + if visitor.statistics.len() != 1 { + return Err(DataFusionError::Internal(format!( + "There must be only one statistics left in the stack, but find {}", + visitor.statistics.len() + ))); + } + + Ok(visitor.statistics.pop_back().unwrap()) +} + +/// Traverse the physical plan and build statistics min max for the given column each node +/// Note: This is a temproray solution until DF's statistics is more mature +/// +struct StatisticsVisitor<'a> { + column_name: &'a str, //String, // todo: not sure enough + statistics: VecDeque, +} + +impl<'a> StatisticsVisitor<'a> { + fn new(column_name: &'a str) -> Self { + Self { + column_name, + statistics: VecDeque::new(), + } + } +} + +impl ExecutionPlanVisitor for StatisticsVisitor<'_> { + type Error = DataFusionError; + + fn pre_visit(&mut self, _plan: &dyn ExecutionPlan) -> Result { + Ok(false) + } + + fn post_visit(&mut self, plan: &dyn ExecutionPlan) -> Result { + // If this is an EmptyExec / PlaceholderRowExec, we don't know about it + if plan.as_any().downcast_ref::().is_some() + || plan.as_any().downcast_ref::().is_some() + { + self.statistics.push_back(ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + distinct_count: Precision::Absent, + }); + } + // If this is leaf node (ParquetExec or RecordBatchExec), compute its statistics and push it to the stack + else if plan.as_any().downcast_ref::().is_some() + || plan.as_any().downcast_ref::().is_some() + { + // get index of the 
given column in the schema + let statistics = match plan.schema().index_of(self.column_name) { + Ok(col_index) => plan.statistics()?.column_statistics[col_index].clone(), + // This is the case of alias, do not optimize by returning no statistics + Err(_) => { + trace!( + " ------------------- No statistics for column {} in PQ/RB", + self.column_name + ); + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + distinct_count: Precision::Absent, + } + } + }; + self.statistics.push_back(statistics); + } + // Non leaf node + else { + // These are cases the stats will be unioned of their children's + // Sort, Dediplicate, Filter, Repartition, Union, SortPreservingMerge, CoalesceBatches + let union_stats = if plan.as_any().downcast_ref::().is_some() + || plan.as_any().downcast_ref::().is_some() + || plan.as_any().downcast_ref::().is_some() + || plan.as_any().downcast_ref::().is_some() + || plan.as_any().downcast_ref::().is_some() + || plan + .as_any() + .downcast_ref::() + .is_some() + || plan + .as_any() + .downcast_ref::() + .is_some() + { + true + } else if plan.as_any().downcast_ref::().is_some() { + // ProjectionExec is a special case. Only union stats if it includes pure columns + projection_includes_pure_columns( + plan.as_any().downcast_ref::().unwrap(), + ) + } else { + false + }; + + // pop statistics of all inputs from the stack + let num_inputs = plan.children().len(); + // num_input must > 0. 
Pop the first one + let mut statistics = self + .statistics + .pop_back() + .expect("No statistics for input plan"); + // pop the rest and update the min and max + for _ in 1..num_inputs { + let input_statistics = self + .statistics + .pop_back() + .expect("No statistics for input plan"); + + if union_stats { + // Convervatively union min max + statistics.null_count = statistics.null_count.add(&input_statistics.null_count); + statistics.max_value = statistics.max_value.max(&input_statistics.max_value); + statistics.min_value = statistics.min_value.min(&input_statistics.min_value); + statistics.distinct_count = Precision::Absent; + }; + } + + if union_stats { + self.statistics.push_back(statistics); + } else { + trace!( + " ------ No statistics for column {} in non-leaf node", + self.column_name + ); + // Make them absent for other cases + self.statistics.push_back(ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + distinct_count: Precision::Absent, + }); + } + } + + Ok(true) + } +} + +fn projection_includes_pure_columns(projection: &ProjectionExec) -> bool { + projection + .expr() + .iter() + .all(|(expr, _col_name)| expr.as_any().downcast_ref::().is_some()) +} + +/// Return min max of a ColumnStatistics with precise values +pub fn column_statistics_min_max( + column_statistics: &ColumnStatistics, +) -> Option<(ScalarValue, ScalarValue)> { + match (&column_statistics.min_value, &column_statistics.max_value) { + (Precision::Exact(min), Precision::Exact(max)) => Some((min.clone(), max.clone())), + // the statistics values are absent or imprecise + _ => None, + } +} + +/// Get statsistics min max of given column name on given plans +/// Return None if one of the inputs does not have statistics or does not include the column +pub fn statistics_min_max( + plans: &[Arc], + column_name: &str, +) -> Option> { + // Get statistics for each plan + let plans_schema_and_stats = plans + .iter() + .map(|plan| 
Ok((Arc::clone(plan), plan.schema(), plan.statistics()?))) + .collect::, DataFusionError>>(); + + // If any without statistics, return none + let Ok(plans_schema_and_stats) = plans_schema_and_stats else { + return None; + }; + + // get value range of the sorted column for each input + let mut min_max_ranges = Vec::with_capacity(plans_schema_and_stats.len()); + for (input, input_schema, input_stats) in plans_schema_and_stats { + // get index of the sorted column in the schema + let Ok(sorted_col_index) = input_schema.index_of(column_name) else { + // panic that the sorted column is not in the schema + panic!("sorted column {} is not in the schema", column_name); + }; + + let column_stats = input_stats.column_statistics; + let sorted_col_stats = column_stats[sorted_col_index].clone(); + match (sorted_col_stats.min_value, sorted_col_stats.max_value) { + (Precision::Exact(min), Precision::Exact(max)) => { + min_max_ranges.push((min, max)); + } + // WARNING: this may produce incorrect results until we use more precision + // as `Inexact` is not guaranteed to cover the actual min and max values + // https://github.com/apache/arrow-datafusion/issues/8078 + (Precision::Inexact(min), Precision::Inexact(max)) => { + if let Some(_deduplicate_exec) = input.as_any().downcast_ref::() { + min_max_ranges.push((min, max)); + } else { + return None; + }; + } + // the statistics values are absent + _ => return None, + } + } + + Some(min_max_ranges) +} + +/// Return true if at least 2 min_max ranges in the given array overlap +pub fn overlap(value_ranges: &[(ScalarValue, ScalarValue)]) -> Result { + // interleave min and max into one iterator + let value_ranges_iter = value_ranges.iter().flat_map(|(min, max)| { + // panics if min > max + if min > max { + panic!("min ({:?}) > max ({:?})", min, max); + } + vec![min.clone(), max.clone()] + }); + + let value_ranges = ScalarValue::iter_to_array(value_ranges_iter)?; + + // rank it + let ranks = rank(&*value_ranges, None)?; + + // check 
overlap by checking if the max is rank right behind its corresponding min + // . non-overlap example: values of min-max pairs [3, 5, 9, 12, 1, 1, 6, 8] + // ranks: [3, 4, 7, 8, 2, 2, 5, 6] : max (even index) = its correspnding min (odd index) for same min max OR min + 1 + // . overlap example: [3, 5, 9, 12, 1, 1, 4, 6] : pair [3, 5] interleaves with pair [4, 6] + // ranks: [3, 5, 7, 8, 2, 2, 4, 6] + for i in (0..ranks.len()).step_by(2) { + if !((ranks[i] == ranks[i + 1]) || (ranks[i + 1] == ranks[i] + 1)) { + return Ok(true); + } + } + + Ok(false) +} + +#[cfg(test)] +mod test { + use crate::{ + provider::chunks_to_physical_nodes, + test::{format_execution_plan, TestChunk}, + }; + + use super::*; + use arrow::datatypes::{DataType, Field}; + use datafusion::{common::Statistics, error::DataFusionError}; + use itertools::Itertools; + use schema::{InfluxFieldType, SchemaBuilder}; + + #[test] + fn test_df_stats_agg_no_cols_no_updates() { + let schema = Schema::new(Vec::::new()); + let agg = DFStatsAggregator::new(&schema); + + let actual = agg.build(); + let expected = DFStatistics { + num_rows: Precision::Exact(0), + total_byte_size: Precision::Exact(0), + column_statistics: Statistics::unknown_column(&schema), + }; + assert_eq!(actual, expected); + } + + #[test] + fn test_df_stats_agg_no_updates() { + let schema = Schema::new(vec![ + Field::new("col1", DataType::UInt64, true), + Field::new("col2", DataType::Utf8, false), + ]); + let agg = DFStatsAggregator::new(&schema); + + let actual = agg.build(); + let expected = DFStatistics { + num_rows: Precision::Exact(0), + total_byte_size: Precision::Exact(0), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Absent, + min_value: Precision::Absent, + distinct_count: Precision::Absent, + }, + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Absent, + min_value: Precision::Absent, + distinct_count: Precision::Absent, + }, + ], + }; + 
assert_eq!(actual, expected); + } + + #[test] + fn test_df_stats_agg_valid_update_partial() { + let schema = Schema::new(vec![ + Field::new("col1", DataType::UInt64, true), + Field::new("col2", DataType::Utf8, false), + ]); + let mut agg = DFStatsAggregator::new(&schema); + + let update_schema = Schema::new(vec![ + Field::new("col1", DataType::UInt64, true), + Field::new("col2", DataType::Utf8, false), + ]); + let update_stats = DFStatistics { + num_rows: Precision::Exact(1), + total_byte_size: Precision::Exact(10), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(100), + max_value: Precision::Exact(ScalarValue::UInt64(Some(100))), + min_value: Precision::Exact(ScalarValue::UInt64(Some(50))), + distinct_count: Precision::Exact(42), + }, + ColumnStatistics { + null_count: Precision::Exact(1_000), + max_value: Precision::Exact(ScalarValue::Utf8(Some("e".to_owned()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some("b".to_owned()))), + distinct_count: Precision::Exact(42), + }, + ], + }; + agg.update(&update_stats, &update_schema); + + let update_schema = Schema::new(vec![Field::new("col2", DataType::Utf8, false)]); + let update_stats = DFStatistics { + num_rows: Precision::Exact(10_000), + total_byte_size: Precision::Exact(100_000), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(1_000_000), + max_value: Precision::Exact(ScalarValue::Utf8(Some("g".to_owned()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some("c".to_owned()))), + distinct_count: Precision::Exact(42), + }], + }; + agg.update(&update_stats, &update_schema); + + let actual = agg.build(); + let expected = DFStatistics { + num_rows: Precision::Exact(10_001), + total_byte_size: Precision::Exact(100_010), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::UInt64(Some(100))), + min_value: Precision::Exact(ScalarValue::UInt64(Some(50))), + distinct_count: 
Precision::Absent, + }, + ColumnStatistics { + null_count: Precision::Exact(1_001_000), + max_value: Precision::Exact(ScalarValue::Utf8(Some("g".to_owned()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some("b".to_owned()))), + distinct_count: Precision::Absent, + }, + ], + }; + assert_eq!(actual, expected); + } + + #[test] + fn test_df_stats_agg_valid_update_col_reorder() { + let schema = Schema::new(vec![ + Field::new("col1", DataType::UInt64, true), + Field::new("col2", DataType::Utf8, false), + ]); + let mut agg = DFStatsAggregator::new(&schema); + + let update_schema = Schema::new(vec![ + Field::new("col1", DataType::UInt64, true), + Field::new("col2", DataType::Utf8, false), + ]); + let update_stats = DFStatistics { + num_rows: Precision::Exact(1), + total_byte_size: Precision::Exact(10), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(100), + max_value: Precision::Exact(ScalarValue::UInt64(Some(100))), + min_value: Precision::Exact(ScalarValue::UInt64(Some(50))), + distinct_count: Precision::Exact(42), + }, + ColumnStatistics { + null_count: Precision::Exact(1_000), + max_value: Precision::Exact(ScalarValue::Utf8(Some("e".to_owned()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some("b".to_owned()))), + distinct_count: Precision::Exact(42), + }, + ], + }; + agg.update(&update_stats, &update_schema); + + let update_schema = Schema::new(vec![ + Field::new("col2", DataType::Utf8, false), + Field::new("col1", DataType::UInt64, true), + ]); + let update_stats = DFStatistics { + num_rows: Precision::Exact(10_000), + total_byte_size: Precision::Exact(100_000), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(1_000_000), + max_value: Precision::Exact(ScalarValue::Utf8(Some("g".to_owned()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some("c".to_owned()))), + distinct_count: Precision::Exact(42), + }, + ColumnStatistics { + null_count: Precision::Exact(10_000_000), + max_value: 
Precision::Exact(ScalarValue::UInt64(Some(99))), + min_value: Precision::Exact(ScalarValue::UInt64(Some(40))), + distinct_count: Precision::Exact(42), + }, + ], + }; + agg.update(&update_stats, &update_schema); + + let actual = agg.build(); + let expected = DFStatistics { + num_rows: Precision::Exact(10_001), + total_byte_size: Precision::Exact(100_010), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(10_000_100), + max_value: Precision::Exact(ScalarValue::UInt64(Some(100))), + min_value: Precision::Exact(ScalarValue::UInt64(Some(40))), + distinct_count: Precision::Absent, + }, + ColumnStatistics { + null_count: Precision::Exact(1_001_000), + max_value: Precision::Exact(ScalarValue::Utf8(Some("g".to_owned()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some("b".to_owned()))), + distinct_count: Precision::Absent, + }, + ], + }; + assert_eq!(actual, expected); + } + + #[test] + fn test_df_stats_agg_ignores_unknown_cols() { + let schema = Schema::new(vec![ + Field::new("col1", DataType::UInt64, true), + Field::new("col2", DataType::Utf8, false), + ]); + let mut agg = DFStatsAggregator::new(&schema); + + let update_schema = Schema::new(vec![ + Field::new("col1", DataType::UInt64, true), + Field::new("col3", DataType::Utf8, false), + ]); + let update_stats = DFStatistics { + num_rows: Precision::Exact(1), + total_byte_size: Precision::Exact(10), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(100), + max_value: Precision::Exact(ScalarValue::UInt64(Some(100))), + min_value: Precision::Exact(ScalarValue::UInt64(Some(50))), + distinct_count: Precision::Exact(42), + }, + ColumnStatistics { + null_count: Precision::Exact(1_000), + max_value: Precision::Exact(ScalarValue::Utf8(Some("e".to_owned()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some("b".to_owned()))), + distinct_count: Precision::Exact(42), + }, + ], + }; + agg.update(&update_stats, &update_schema); + + let actual = agg.build(); + let 
expected = DFStatistics { + num_rows: Precision::Exact(1), + total_byte_size: Precision::Exact(10), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(100), + max_value: Precision::Exact(ScalarValue::UInt64(Some(100))), + min_value: Precision::Exact(ScalarValue::UInt64(Some(50))), + distinct_count: Precision::Absent, + }, + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + distinct_count: Precision::Absent, + }, + ], + }; + assert_eq!(actual, expected); + } + + #[test] + fn test_df_stats_agg_invalidation() { + let schema = Schema::new(vec![ + Field::new("col1", DataType::UInt64, true), + Field::new("col2", DataType::Utf8, false), + ]); + + let update_stats = DFStatistics { + num_rows: Precision::Exact(1), + total_byte_size: Precision::Exact(10), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(100), + max_value: Precision::Exact(ScalarValue::UInt64(Some(100))), + min_value: Precision::Exact(ScalarValue::UInt64(Some(50))), + distinct_count: Precision::Exact(42), + }, + ColumnStatistics { + null_count: Precision::Exact(1_000), + max_value: Precision::Exact(ScalarValue::Utf8(Some("e".to_owned()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some("b".to_owned()))), + distinct_count: Precision::Exact(42), + }, + ], + }; + let agg_stats = DFStatistics { + num_rows: Precision::Exact(2), + total_byte_size: Precision::Exact(20), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(200), + max_value: Precision::Exact(ScalarValue::UInt64(Some(100))), + min_value: Precision::Exact(ScalarValue::UInt64(Some(50))), + distinct_count: Precision::Absent, + }, + ColumnStatistics { + null_count: Precision::Exact(2_000), + max_value: Precision::Exact(ScalarValue::Utf8(Some("e".to_owned()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some("b".to_owned()))), + distinct_count: Precision::Absent, + }, + ], + }; + + 
#[derive(Debug, Clone, Copy)] + enum ColMode { + NullCount, + MaxValue, + MinValue, + } + + #[derive(Debug, Clone, Copy)] + enum Mode { + NumRows, + TotalByteSize, + ColumnStatistics, + Col(usize, ColMode), + } + + impl Mode { + fn mask(&self, mut stats: DFStatistics) -> DFStatistics { + match self { + Self::NumRows => { + stats.num_rows = Precision::Absent; + } + Self::TotalByteSize => { + stats.total_byte_size = Precision::Absent; + } + Self::ColumnStatistics => { + let num_cols = stats.column_statistics.len(); + stats.column_statistics = vec![ColumnStatistics::new_unknown(); num_cols] + } + Self::Col(idx, mode) => { + let stats = &mut stats.column_statistics[*idx]; + + match mode { + ColMode::NullCount => { + stats.null_count = Precision::Absent; + } + ColMode::MaxValue => { + stats.max_value = Precision::Absent; + } + ColMode::MinValue => { + stats.min_value = Precision::Absent; + } + } + } + } + stats + } + } + + for mode in [ + Mode::NumRows, + Mode::TotalByteSize, + Mode::ColumnStatistics, + Mode::Col(0, ColMode::NullCount), + Mode::Col(0, ColMode::MaxValue), + Mode::Col(0, ColMode::MinValue), + Mode::Col(1, ColMode::NullCount), + ] { + println!("mode: {mode:?}"); + + for invalid_mask in [[false, true], [true, false], [true, true]] { + println!("invalid_mask: {invalid_mask:?}"); + let mut agg = DFStatsAggregator::new(&schema); + + for invalid in invalid_mask { + let mut update_stats = update_stats.clone(); + if invalid { + update_stats = mode.mask(update_stats); + } + agg.update(&update_stats, &schema); + } + + let actual = agg.build(); + + let expected = mode.mask(agg_stats.clone()); + assert_eq!(actual, expected); + } + } + } + + #[test] + #[should_panic(expected = "stats (0) and schema (1) have different column count")] + fn test_df_stats_agg_asserts_schema_stats_match() { + let schema = Schema::new(vec![Field::new("col1", DataType::UInt64, true)]); + let mut agg = DFStatsAggregator::new(&schema); + + let update_schema = 
Schema::new(vec![Field::new("col1", DataType::UInt64, true)]); + let update_stats = DFStatistics { + num_rows: Precision::Exact(1), + total_byte_size: Precision::Exact(10), + column_statistics: vec![], + }; + agg.update(&update_stats, &update_schema); + } + + #[test] + fn test_stats_for_one_chunk() { + // schema with one tag, one field, time and CHUNK_ORDER_COLUMN_NAME + let schema: SchemaRef = SchemaBuilder::new() + .tag("tag") + .influx_field("field", InfluxFieldType::Float) + .timestamp() + .influx_field(CHUNK_ORDER_COLUMN_NAME, InfluxFieldType::Integer) + .build() + .unwrap() + .into(); + + // create a test chunk with one tag, one filed, time and CHUNK_ORDER_COLUMN_NAME + let record_batch_chunk = Arc::new( + TestChunk::new("t") + .with_tag_column_with_stats("tag", Some("AL"), Some("MT")) + .with_time_column_with_stats(Some(10), Some(20)) + .with_i64_field_column_with_stats("field", Some(0), Some(100)) + .with_i64_field_column_with_stats(CHUNK_ORDER_COLUMN_NAME, Some(5), Some(6)), + ); + + // create them same test chunk but with a parquet file + let parquet_chunk = Arc::new( + TestChunk::new("t") + .with_tag_column_with_stats("tag", Some("AL"), Some("MT")) + .with_i64_field_column_with_stats("field", Some(0), Some(100)) + .with_time_column_with_stats(Some(10), Some(20)) + .with_i64_field_column_with_stats(CHUNK_ORDER_COLUMN_NAME, Some(5), Some(6)) + .with_dummy_parquet_file(), + ); + + let expected_stats = [ + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Utf8(Some("MT".to_string()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some("AL".to_string()))), + distinct_count: Precision::Absent, + }, + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Int64(Some(100))), + min_value: Precision::Exact(ScalarValue::Int64(Some(0))), + distinct_count: Precision::Absent, + }, + ColumnStatistics { + null_count: Precision::Absent, + max_value: 
Precision::Exact(ScalarValue::TimestampNanosecond(Some(20), None)), + min_value: Precision::Exact(ScalarValue::TimestampNanosecond(Some(10), None)), + distinct_count: Precision::Absent, + }, + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Int64(Some(6))), + min_value: Precision::Exact(ScalarValue::Int64(Some(0))), + distinct_count: Precision::Absent, + }, + ]; + + let record_batch_stats = + build_statistics_for_chunks(&[record_batch_chunk], Arc::clone(&schema)); + assert_eq!(record_batch_stats.column_statistics, expected_stats); + + let parquet_stats = build_statistics_for_chunks(&[parquet_chunk], schema); + assert_eq!(parquet_stats.column_statistics, expected_stats); + } + + #[test] + fn test_stats_for_two_chunks() { + // schema with one tag, one field, time and CHUNK_ORDER_COLUMN_NAME + let schema: SchemaRef = SchemaBuilder::new() + .tag("tag") + .influx_field("field", InfluxFieldType::Float) + .timestamp() + .influx_field(CHUNK_ORDER_COLUMN_NAME, InfluxFieldType::Integer) + .build() + .unwrap() + .into(); + + // create a test chunk with one tag, one filed, time and CHUNK_ORDER_COLUMN_NAME + let record_batch_chunk_1 = Arc::new( + TestChunk::new("t1") + .with_tag_column_with_stats("tag", Some("AL"), Some("MT")) + .with_time_column_with_stats(Some(10), Some(20)) + .with_i64_field_column_with_stats("field", Some(0), Some(100)) + .with_i64_field_column_with_stats(CHUNK_ORDER_COLUMN_NAME, Some(5), Some(6)), + ); + + let record_batch_chunk_2 = Arc::new( + TestChunk::new("t2") + .with_tag_column_with_stats("tag", Some("MI"), Some("WA")) + .with_time_column_with_stats(Some(50), Some(80)) + .with_i64_field_column_with_stats("field", Some(0), Some(70)) + .with_i64_field_column_with_stats(CHUNK_ORDER_COLUMN_NAME, Some(7), Some(15)), + ); + + // create them same test chunk but with a parquet file + let parquet_chunk_1 = Arc::new( + TestChunk::new("t1") + .with_tag_column_with_stats("tag", Some("AL"), Some("MT")) + 
.with_i64_field_column_with_stats("field", Some(0), Some(100)) + .with_time_column_with_stats(Some(10), Some(20)) + .with_i64_field_column_with_stats(CHUNK_ORDER_COLUMN_NAME, Some(5), Some(6)) + .with_dummy_parquet_file(), + ); + + let parquet_chunk_2 = Arc::new( + TestChunk::new("t2") + .with_tag_column_with_stats("tag", Some("MI"), Some("WA")) + .with_i64_field_column_with_stats("field", Some(0), Some(70)) + .with_time_column_with_stats(Some(50), Some(80)) + .with_i64_field_column_with_stats(CHUNK_ORDER_COLUMN_NAME, Some(7), Some(15)) + .with_dummy_parquet_file(), + ); + + let expected_stats = [ + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Utf8(Some("WA".to_string()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some("AL".to_string()))), + distinct_count: Precision::Absent, + }, + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Int64(Some(100))), + min_value: Precision::Exact(ScalarValue::Int64(Some(0))), + distinct_count: Precision::Absent, + }, + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::TimestampNanosecond(Some(80), None)), + min_value: Precision::Exact(ScalarValue::TimestampNanosecond(Some(10), None)), + distinct_count: Precision::Absent, + }, + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Int64(Some(15))), + min_value: Precision::Exact(ScalarValue::Int64(Some(0))), + distinct_count: Precision::Absent, + }, + ]; + + let record_batch_stats = build_statistics_for_chunks( + &[record_batch_chunk_1, record_batch_chunk_2], + Arc::clone(&schema), + ); + assert_eq!(record_batch_stats.column_statistics, expected_stats); + + let parquet_stats = + build_statistics_for_chunks(&[parquet_chunk_1, parquet_chunk_2], schema); + assert_eq!(parquet_stats.column_statistics, expected_stats); + } + + #[test] + fn test_compute_statistics_min_max() { + // schema with one tag, one 
field, time and CHUNK_ORDER_COLUMN_NAME + let schema: SchemaRef = SchemaBuilder::new() + .tag("tag") + .influx_field("float_field", InfluxFieldType::Float) + .influx_field("int_field", InfluxFieldType::Integer) + .influx_field("string_field", InfluxFieldType::String) + .tag("tag_no_val") // no chunks have values for this + .influx_field("field_no_val", InfluxFieldType::Integer) + .timestamp() + .build() + .unwrap() + .into(); + + let parquet_chunk = Arc::new( + TestChunk::new("t") + .with_time_column_with_stats(Some(10), Some(100)) + .with_tag_column_with_stats("tag", Some("MA"), Some("VT")) + .with_f64_field_column_with_stats("float_field", Some(10.1), Some(100.4)) + .with_i64_field_column_with_stats("int_field", Some(30), Some(50)) + .with_string_field_column_with_stats("string_field", Some("orange"), Some("plum")) + // only this chunk has value for this field + .with_i64_field_column_with_stats("field_no_val", Some(30), Some(50)) + .with_dummy_parquet_file(), + ) as Arc; + + let record_batch_chunk = Arc::new( + TestChunk::new("t") + .with_time_column_with_stats(Some(20), Some(200)) + .with_tag_column_with_stats("tag", Some("Boston"), Some("DC")) + .with_f64_field_column_with_stats("float_field", Some(15.6), Some(30.0)) + .with_i64_field_column_with_stats("int_field", Some(1), Some(50)) + .with_string_field_column_with_stats("string_field", Some("banana"), Some("plum")), + ) as Arc; + + let plan_pq = chunks_to_physical_nodes(&schema, None, vec![parquet_chunk], 1); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan_pq), + @r###" + --- + - " UnionExec" + - " ParquetExec: file_groups={1 group: [[0.parquet]]}, projection=[tag, float_field, int_field, string_field, tag_no_val, field_no_val, time]" + "### + ); + + let plan_rb = chunks_to_physical_nodes(&schema, None, vec![record_batch_chunk], 1); + insta::assert_yaml_snapshot!( + format_execution_plan(&plan_rb), + @r###" + --- + - " UnionExec" + - " RecordBatchesExec: chunks=1, projection=[tag, float_field, 
int_field, string_field, tag_no_val, field_no_val, time]" + "### + ); + + // Stats for time + // parquet + let time_stats = compute_stats_column_min_max(&*plan_pq, "time").unwrap(); + let min_max = column_statistics_min_max(&time_stats).unwrap(); + let expected_time_stats = ( + ScalarValue::TimestampNanosecond(Some(10), None), + ScalarValue::TimestampNanosecond(Some(100), None), + ); + assert_eq!(min_max, expected_time_stats); + // record batch + let time_stats = compute_stats_column_min_max(&*plan_rb, "time").unwrap(); + let min_max = column_statistics_min_max(&time_stats).unwrap(); + let expected_time_stats = ( + ScalarValue::TimestampNanosecond(Some(20), None), + ScalarValue::TimestampNanosecond(Some(200), None), + ); + assert_eq!(min_max, expected_time_stats); + + // Stats for tag + // parquet + let tag_stats = compute_stats_column_min_max(&*plan_pq, "tag").unwrap(); + let min_max = column_statistics_min_max(&tag_stats).unwrap(); + let expected_tag_stats = ( + ScalarValue::Utf8(Some("MA".to_string())), + ScalarValue::Utf8(Some("VT".to_string())), + ); + assert_eq!(min_max, expected_tag_stats); + // record batch + let tag_stats = compute_stats_column_min_max(&*plan_rb, "tag").unwrap(); + let min_max = column_statistics_min_max(&tag_stats).unwrap(); + let expected_tag_stats = ( + ScalarValue::Utf8(Some("Boston".to_string())), + ScalarValue::Utf8(Some("DC".to_string())), + ); + assert_eq!(min_max, expected_tag_stats); + + // Stats for field + // parquet + let float_stats = compute_stats_column_min_max(&*plan_pq, "float_field").unwrap(); + let min_max = column_statistics_min_max(&float_stats).unwrap(); + let expected_float_stats = ( + ScalarValue::Float64(Some(10.1)), + ScalarValue::Float64(Some(100.4)), + ); + assert_eq!(min_max, expected_float_stats); + // record batch + let float_stats = compute_stats_column_min_max(&*plan_rb, "float_field").unwrap(); + let min_max = column_statistics_min_max(&float_stats).unwrap(); + let expected_float_stats = ( + 
ScalarValue::Float64(Some(15.6)), + ScalarValue::Float64(Some(30.0)), + ); + assert_eq!(min_max, expected_float_stats); + + // Stats for int + // parquet + let int_stats = compute_stats_column_min_max(&*plan_pq, "int_field").unwrap(); + let min_max = column_statistics_min_max(&int_stats).unwrap(); + let expected_int_stats = (ScalarValue::Int64(Some(30)), ScalarValue::Int64(Some(50))); + assert_eq!(min_max, expected_int_stats); + // record batch + let int_stats = compute_stats_column_min_max(&*plan_rb, "int_field").unwrap(); + let min_max = column_statistics_min_max(&int_stats).unwrap(); + let expected_int_stats = (ScalarValue::Int64(Some(1)), ScalarValue::Int64(Some(50))); + assert_eq!(min_max, expected_int_stats); + + // Stats for string + // parquet + let string_stats = compute_stats_column_min_max(&*plan_pq, "string_field").unwrap(); + let min_max = column_statistics_min_max(&string_stats).unwrap(); + let expected_string_stats = ( + ScalarValue::Utf8(Some("orange".to_string())), + ScalarValue::Utf8(Some("plum".to_string())), + ); + assert_eq!(min_max, expected_string_stats); + // record batch + let string_stats = compute_stats_column_min_max(&*plan_rb, "string_field").unwrap(); + let min_max = column_statistics_min_max(&string_stats).unwrap(); + let expected_string_stats = ( + ScalarValue::Utf8(Some("banana".to_string())), + ScalarValue::Utf8(Some("plum".to_string())), + ); + assert_eq!(min_max, expected_string_stats); + + // no tats on parquet + let tag_no_stats = compute_stats_column_min_max(&*plan_pq, "tag_no_val").unwrap(); + let min_max = column_statistics_min_max(&tag_no_stats); + assert!(min_max.is_none()); + + // no stats on record batch + let field_no_stats = compute_stats_column_min_max(&*plan_rb, "field_no_val").unwrap(); + let min_max = column_statistics_min_max(&field_no_stats); + assert!(min_max.is_none()); + } + + #[test] + fn test_statistics_min_max() { + // schema with one tag, one field, time and CHUNK_ORDER_COLUMN_NAME + let schema: SchemaRef 
= SchemaBuilder::new() + .tag("tag") + .influx_field("float_field", InfluxFieldType::Float) + .influx_field("int_field", InfluxFieldType::Integer) + .influx_field("string_field", InfluxFieldType::String) + .tag("tag_no_val") // no chunks have values for this + .influx_field("field_no_val", InfluxFieldType::Integer) + .timestamp() + .build() + .unwrap() + .into(); + + let parquet_chunk = Arc::new( + TestChunk::new("t") + .with_time_column_with_stats(Some(10), Some(100)) + .with_tag_column_with_stats("tag", Some("MA"), Some("VT")) + .with_f64_field_column_with_stats("float_field", Some(10.1), Some(100.4)) + .with_i64_field_column_with_stats("int_field", Some(30), Some(50)) + .with_string_field_column_with_stats("string_field", Some("orange"), Some("plum")) + // only this chunk has value for this field + .with_i64_field_column_with_stats("field_no_val", Some(30), Some(50)) + .with_dummy_parquet_file(), + ) as Arc; + + let record_batch_chunk = Arc::new( + TestChunk::new("t") + .with_time_column_with_stats(Some(20), Some(200)) + .with_tag_column_with_stats("tag", Some("Boston"), Some("DC")) + .with_f64_field_column_with_stats("float_field", Some(15.6), Some(30.0)) + .with_i64_field_column_with_stats("int_field", Some(1), Some(50)) + .with_string_field_column_with_stats("string_field", Some("banana"), Some("plum")), + ) as Arc; + + let plan1 = chunks_to_physical_nodes(&schema, None, vec![parquet_chunk], 1); + let plan2 = chunks_to_physical_nodes(&schema, None, vec![record_batch_chunk], 1); + + let time_stats = + statistics_min_max(&[Arc::clone(&plan1), Arc::clone(&plan2)], "time").unwrap(); + let expected_time_stats = [ + ( + ScalarValue::TimestampNanosecond(Some(10), None), + ScalarValue::TimestampNanosecond(Some(100), None), + ), + ( + ScalarValue::TimestampNanosecond(Some(20), None), + ScalarValue::TimestampNanosecond(Some(200), None), + ), + ]; + assert_eq!(time_stats, expected_time_stats); + + let tag_stats = + statistics_min_max(&[Arc::clone(&plan1), 
Arc::clone(&plan2)], "tag").unwrap(); + let expected_tag_stats = [ + ( + ScalarValue::Utf8(Some("MA".to_string())), + ScalarValue::Utf8(Some("VT".to_string())), + ), + ( + ScalarValue::Utf8(Some("Boston".to_string())), + ScalarValue::Utf8(Some("DC".to_string())), + ), + ]; + assert_eq!(tag_stats, expected_tag_stats); + + let float_stats = + statistics_min_max(&[Arc::clone(&plan1), Arc::clone(&plan2)], "float_field").unwrap(); + let expected_float_stats = [ + ( + ScalarValue::Float64(Some(10.1)), + ScalarValue::Float64(Some(100.4)), + ), + ( + ScalarValue::Float64(Some(15.6)), + ScalarValue::Float64(Some(30.0)), + ), + ]; + assert_eq!(float_stats, expected_float_stats); + + let int_stats = + statistics_min_max(&[Arc::clone(&plan1), Arc::clone(&plan2)], "int_field").unwrap(); + let expected_int_stats = [ + (ScalarValue::Int64(Some(30)), ScalarValue::Int64(Some(50))), + (ScalarValue::Int64(Some(1)), ScalarValue::Int64(Some(50))), + ]; + assert_eq!(int_stats, expected_int_stats); + + let string_stats = + statistics_min_max(&[Arc::clone(&plan1), Arc::clone(&plan2)], "string_field").unwrap(); + let expected_string_stats = [ + ( + ScalarValue::Utf8(Some("orange".to_string())), + ScalarValue::Utf8(Some("plum".to_string())), + ), + ( + ScalarValue::Utf8(Some("banana".to_string())), + ScalarValue::Utf8(Some("plum".to_string())), + ), + ]; + assert_eq!(string_stats, expected_string_stats); + + let tag_no_stat = + statistics_min_max(&[Arc::clone(&plan1), Arc::clone(&plan2)], "tag_no_val"); + assert!(tag_no_stat.is_none()); + + let field_no_stat = + statistics_min_max(&[Arc::clone(&plan1), Arc::clone(&plan2)], "field_no_val"); + assert!(field_no_stat.is_none()); + } + + #[test] + fn test_non_overlap_time() { + let pair_1 = ( + ScalarValue::TimestampNanosecond(Some(10), None), + ScalarValue::TimestampNanosecond(Some(20), None), + ); + let pair_2 = ( + ScalarValue::TimestampNanosecond(Some(100), None), + ScalarValue::TimestampNanosecond(Some(150), None), + ); + let pair_3 = ( + 
ScalarValue::TimestampNanosecond(Some(60), None), + ScalarValue::TimestampNanosecond(Some(65), None), + ); + + let overlap = overlap_all(&vec![pair_1, pair_2, pair_3]).unwrap(); + assert!(!overlap); + } + + #[test] + fn test_overlap_time() { + let pair_1 = ( + ScalarValue::TimestampNanosecond(Some(10), None), + ScalarValue::TimestampNanosecond(Some(20), None), + ); + let pair_2 = ( + ScalarValue::TimestampNanosecond(Some(100), None), + ScalarValue::TimestampNanosecond(Some(150), None), + ); + let pair_3 = ( + ScalarValue::TimestampNanosecond(Some(8), None), + ScalarValue::TimestampNanosecond(Some(10), None), + ); + + let overlap = overlap_all(&vec![pair_1, pair_2, pair_3]).unwrap(); + assert!(overlap); + } + + #[test] + fn test_non_overlap_integer() { + // [3, 5, 9, 12, 1, 1, 6, 8] + let pair_1 = (ScalarValue::Int16(Some(3)), ScalarValue::Int16(Some(5))); + let pair_2 = (ScalarValue::Int16(Some(9)), ScalarValue::Int16(Some(12))); + let pair_3 = (ScalarValue::Int16(Some(1)), ScalarValue::Int16(Some(1))); + let pair_4 = (ScalarValue::Int16(Some(6)), ScalarValue::Int16(Some(8))); + + let overlap = overlap_all(&vec![pair_1, pair_2, pair_3, pair_4]).unwrap(); + assert!(!overlap); + } + + #[test] + fn test_overlap_integer() { + // [3, 5, 9, 12, 1, 1, 4, 6] + let pair_1 = (ScalarValue::Int16(Some(3)), ScalarValue::Int16(Some(5))); + let pair_2 = (ScalarValue::Int16(Some(9)), ScalarValue::Int16(Some(12))); + let pair_3 = (ScalarValue::Int16(Some(1)), ScalarValue::Int16(Some(1))); + let pair_4 = (ScalarValue::Int16(Some(4)), ScalarValue::Int16(Some(6))); + + let overlap = overlap_all(&vec![pair_1, pair_2, pair_3, pair_4]).unwrap(); + assert!(overlap); + } + + #[test] + fn test_non_overlap_integer_ascending_null_first() { + // [3, 5, null, null, 1, 1, 6, 8] + let pair_1 = (ScalarValue::Int16(Some(3)), ScalarValue::Int16(Some(5))); + let pair_2 = (ScalarValue::Int16(None), ScalarValue::Int16(None)); + let pair_3 = (ScalarValue::Int16(Some(1)), ScalarValue::Int16(Some(2))); + 
let pair_4 = (ScalarValue::Int16(Some(6)), ScalarValue::Int16(Some(8))); + + let overlap = overlap_all(&vec![pair_1, pair_2, pair_3, pair_4]).unwrap(); + assert!(!overlap); + } + + #[test] + fn test_overlap_integer_ascending_null_first() { + // [3, 5, null, null, 1, 1, 4, 6] + let pair_1 = (ScalarValue::Int16(Some(3)), ScalarValue::Int16(Some(5))); + let pair_2 = (ScalarValue::Int16(None), ScalarValue::Int16(None)); + let pair_3 = (ScalarValue::Int16(Some(1)), ScalarValue::Int16(Some(2))); + let pair_4 = (ScalarValue::Int16(Some(4)), ScalarValue::Int16(Some(6))); + + let overlap = overlap_all(&vec![pair_1, pair_2, pair_3, pair_4]).unwrap(); + assert!(overlap); + } + + #[test] + fn test_non_overlap_string_ascending_null_first() { + // ['e', 'h', null, null, 'a', 'a', 'k', 'q'] + let pair_1 = ( + ScalarValue::Utf8(Some('e'.to_string())), + ScalarValue::Utf8(Some('h'.to_string())), + ); + let pair_2 = (ScalarValue::Utf8(None), ScalarValue::Utf8(None)); + let pair_3 = ( + ScalarValue::Utf8(Some('a'.to_string())), + ScalarValue::Utf8(Some('a'.to_string())), + ); + + let overlap = overlap_all(&vec![pair_1, pair_2, pair_3]).unwrap(); + assert!(!overlap); + } + + #[test] + fn test_overlap_string_ascending_null_first() { + // ['e', 'h', null, null, 'a', 'f', 'k', 'q'] + let pair_1 = ( + ScalarValue::Utf8(Some('e'.to_string())), + ScalarValue::Utf8(Some('h'.to_string())), + ); + let pair_2 = (ScalarValue::Utf8(None), ScalarValue::Utf8(None)); + let pair_3 = ( + ScalarValue::Utf8(Some('a'.to_string())), + ScalarValue::Utf8(Some('f'.to_string())), + ); + + let overlap = overlap_all(&vec![pair_1, pair_2, pair_3]).unwrap(); + assert!(overlap); + } + + #[test] + #[should_panic(expected = "Internal(\"Empty iterator passed to ScalarValue::iter_to_array\")")] + fn test_overlap_empty() { + let _overlap = overlap_all(&[]); + } + + #[should_panic(expected = "min (Int16(3)) > max (Int16(2))")] + #[test] + fn test_overlap_panic() { + // max < min + let pair_1 = 
(ScalarValue::Int16(Some(3)), ScalarValue::Int16(Some(2))); + let _overlap = overlap_all(&[pair_1]); + } + + /// Runs `overlap` on all permutations of the given `value_range`es and asserts that the result is + /// the same. Returns that result + fn overlap_all(value_ranges: &[(ScalarValue, ScalarValue)]) -> Result { + let n = value_ranges.len(); + + let mut overlaps_all_permutations = value_ranges + .iter() + .cloned() + .permutations(n) + .map(|v| overlap(&v)); + + let Some(first) = overlaps_all_permutations.next() else { + return overlap(value_ranges); + }; + + let first = first.unwrap(); + + for result in overlaps_all_permutations { + assert_eq!(&result.unwrap(), &first); + } + + Ok(first) + } +} diff --git a/iox_query/src/test.rs b/iox_query/src/test.rs new file mode 100644 index 0000000..e969776 --- /dev/null +++ b/iox_query/src/test.rs @@ -0,0 +1,1220 @@ +//! This module provides a reference implementation of [`QueryNamespace`] for use in testing. +//! +//! AKA it is a Mock + +use crate::{ + exec::{ + stringset::{StringSet, StringSetRef}, + Executor, ExecutorType, IOxSessionContext, + }, + pruning::prune_chunks, + query_log::{QueryLog, StateReceived}, + QueryChunk, QueryChunkData, QueryCompletedToken, QueryNamespace, QueryNamespaceProvider, + QueryText, +}; +use arrow::array::{BooleanArray, Float64Array}; +use arrow::datatypes::SchemaRef; +use arrow::{ + array::{ + ArrayRef, DictionaryArray, Int64Array, StringArray, TimestampNanosecondArray, UInt64Array, + }, + datatypes::{DataType, Int32Type, TimeUnit}, + record_batch::RecordBatch, +}; +use async_trait::async_trait; +use data_types::{ChunkId, ChunkOrder, NamespaceId, PartitionKey, TableId, TransitionPartitionId}; +use datafusion::common::stats::Precision; +use datafusion::error::DataFusionError; +use datafusion::execution::context::SessionState; +use datafusion::logical_expr::Expr; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::{catalog::schema::SchemaProvider, logical_expr::LogicalPlan}; 
+use datafusion::{catalog::CatalogProvider, physical_plan::displayable}; +use datafusion::{ + datasource::{object_store::ObjectStoreUrl, TableProvider, TableType}, + physical_plan::{ColumnStatistics, Statistics as DataFusionStatistics}, + scalar::ScalarValue, +}; +use datafusion_util::{config::DEFAULT_SCHEMA, option_to_precision, timestamptz_nano}; +use iox_time::SystemProvider; +use itertools::Itertools; +use object_store::{path::Path, ObjectMeta}; +use parking_lot::Mutex; +use parquet_file::storage::ParquetExecInput; +use schema::{ + builder::SchemaBuilder, merge::SchemaMerger, sort::SortKey, Schema, TIME_COLUMN_NAME, +}; +use std::{ + any::Any, + collections::{BTreeMap, HashMap}, + fmt, + num::NonZeroU64, + sync::Arc, +}; +use trace::{ctx::SpanContext, span::Span}; +use tracker::{AsyncSemaphoreMetrics, InstrumentedAsyncOwnedSemaphorePermit}; + +#[derive(Debug)] +pub struct TestDatabaseStore { + databases: Mutex>>, + executor: Arc, + pub metric_registry: Arc, + pub query_semaphore: Arc, +} + +impl TestDatabaseStore { + pub fn new() -> Self { + Self::default() + } + + pub fn new_with_semaphore_size(semaphore_size: usize) -> Self { + let metric_registry = Arc::new(metric::Registry::default()); + let semaphore_metrics = Arc::new(AsyncSemaphoreMetrics::new( + &metric_registry, + &[("semaphore", "query_execution")], + )); + Self { + databases: Mutex::new(BTreeMap::new()), + executor: Arc::new(Executor::new_testing()), + metric_registry, + query_semaphore: Arc::new(semaphore_metrics.new_semaphore(semaphore_size)), + } + } + + pub async fn db_or_create(&self, name: &str) -> Arc { + let mut databases = self.databases.lock(); + + if let Some(db) = databases.get(name) { + Arc::clone(db) + } else { + let new_db = Arc::new(TestDatabase::new(Arc::clone(&self.executor))); + databases.insert(name.to_string(), Arc::clone(&new_db)); + new_db + } + } +} + +impl Default for TestDatabaseStore { + fn default() -> Self { + Self::new_with_semaphore_size(u16::MAX as usize) + } +} + 
+#[async_trait] +impl QueryNamespaceProvider for TestDatabaseStore { + /// Retrieve the database specified name + async fn db( + &self, + name: &str, + _span: Option, + _include_debug_info_tables: bool, + ) -> Option> { + let databases = self.databases.lock(); + + databases.get(name).cloned().map(|ns| ns as _) + } + + async fn acquire_semaphore(&self, span: Option) -> InstrumentedAsyncOwnedSemaphorePermit { + Arc::clone(&self.query_semaphore) + .acquire_owned(span) + .await + .unwrap() + } +} + +#[derive(Debug)] +pub struct TestDatabase { + executor: Arc, + /// Partitions which have been saved to this test database + /// Key is partition name + /// Value is map of chunk_id to chunk + partitions: Mutex>>>, + + /// `column_names` to return upon next request + column_names: Arc>>, + + /// The predicate passed to the most recent call to `chunks()` + chunks_predicate: Mutex>, + + /// Retention time ns. + retention_time_ns: Option, +} + +impl TestDatabase { + pub fn new(executor: Arc) -> Self { + Self { + executor, + partitions: Default::default(), + column_names: Default::default(), + chunks_predicate: Default::default(), + retention_time_ns: None, + } + } + + /// Add a test chunk to the database + pub fn add_chunk(&self, partition_key: &str, chunk: Arc) -> &Self { + let mut partitions = self.partitions.lock(); + let chunks = partitions.entry(partition_key.to_string()).or_default(); + chunks.insert(chunk.id(), chunk); + self + } + + /// Add a test chunk to the database + pub fn with_chunk(self, partition_key: &str, chunk: Arc) -> Self { + self.add_chunk(partition_key, chunk); + self + } + + /// Get the specified chunk + pub fn get_chunk(&self, partition_key: &str, id: ChunkId) -> Option> { + self.partitions + .lock() + .get(partition_key) + .and_then(|p| p.get(&id).cloned()) + } + + /// Return the most recent predicate passed to get_chunks() + pub fn get_chunks_predicate(&self) -> Vec { + self.chunks_predicate.lock().clone() + } + + /// Set the list of column names that 
will be returned on a call to + /// column_names + pub fn set_column_names(&self, column_names: Vec) { + let column_names = column_names.into_iter().collect::(); + let column_names = Arc::new(column_names); + + *Arc::clone(&self.column_names).lock() = Some(column_names) + } + + /// Set retention time. + pub fn with_retention_time_ns(mut self, retention_time_ns: Option) -> Self { + self.retention_time_ns = retention_time_ns; + self + } +} + +#[async_trait] +impl QueryNamespace for TestDatabase { + async fn chunks( + &self, + table_name: &str, + filters: &[Expr], + _projection: Option<&Vec>, + _ctx: IOxSessionContext, + ) -> Result>, DataFusionError> { + // save last predicate + *self.chunks_predicate.lock() = filters.to_vec(); + + let partitions = self.partitions.lock().clone(); + Ok(partitions + .values() + .flat_map(|x| x.values()) + // filter by table + .filter(|c| c.table_name == table_name) + // only keep chunks if their statistics overlap + .filter(|c| { + prune_chunks( + c.schema(), + &[Arc::clone(*c) as Arc], + filters, + ) + .ok() + .map(|res| res[0]) + .unwrap_or(true) + }) + .map(|x| Arc::clone(x) as Arc) + .collect::>()) + } + + fn retention_time_ns(&self) -> Option { + self.retention_time_ns + } + + fn record_query( + &self, + span_ctx: Option<&SpanContext>, + query_type: &'static str, + query_text: QueryText, + ) -> QueryCompletedToken { + QueryLog::new(0, Arc::new(SystemProvider::new())).push( + NamespaceId::new(1), + Arc::from("ns"), + query_type, + query_text, + span_ctx.map(|s| s.trace_id), + ) + } + + fn new_query_context(&self, span_ctx: Option) -> IOxSessionContext { + // Note: unlike Db this does not register a catalog provider + self.executor + .new_execution_config(ExecutorType::Query) + .with_default_catalog(Arc::new(TestDatabaseCatalogProvider::from_test_database( + self, + ))) + .with_span_context(span_ctx) + .build() + } +} + +struct TestDatabaseCatalogProvider { + partitions: BTreeMap>>, +} + +impl TestDatabaseCatalogProvider { + fn 
from_test_database(db: &TestDatabase) -> Self { + Self { + partitions: db.partitions.lock().clone(), + } + } +} + +impl CatalogProvider for TestDatabaseCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + vec![DEFAULT_SCHEMA.to_string()] + } + + fn schema(&self, name: &str) -> Option> { + match name { + DEFAULT_SCHEMA => Some(Arc::new(TestDatabaseSchemaProvider { + partitions: self.partitions.clone(), + })), + _ => None, + } + } +} + +struct TestDatabaseSchemaProvider { + partitions: BTreeMap>>, +} + +#[async_trait] +impl SchemaProvider for TestDatabaseSchemaProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec { + self.partitions + .values() + .flat_map(|c| c.values()) + .map(|c| c.table_name.to_owned()) + .unique() + .collect() + } + + async fn table(&self, name: &str) -> Option> { + Some(Arc::new(TestDatabaseTableProvider { + partitions: self + .partitions + .values() + .flat_map(|chunks| chunks.values().filter(|c| c.table_name() == name)) + .map(Clone::clone) + .collect(), + })) + } + + fn table_exist(&self, name: &str) -> bool { + self.table_names().contains(&name.to_string()) + } +} + +struct TestDatabaseTableProvider { + partitions: Vec>, +} + +#[async_trait] +impl TableProvider for TestDatabaseTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.partitions + .iter() + .fold(SchemaMerger::new(), |merger, chunk| { + merger.merge(chunk.schema()).expect("consistent schemas") + }) + .build() + .as_arrow() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _ctx: &SessionState, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> crate::exec::context::Result> { + unimplemented!() + } +} + +#[derive(Debug, Clone)] +enum TestChunkData { + RecordBatches(Vec), + Parquet(ParquetExecInput), +} + +#[derive(Debug, Clone)] +pub struct TestChunk { + /// Table name + table_name: 
String, + + /// Schema of the table + schema: Schema, + + /// Values for stats() + column_stats: HashMap, + num_rows: Option, + + id: ChunkId, + + partition_id: TransitionPartitionId, + + /// Set the flag if this chunk might contain duplicates + may_contain_pk_duplicates: bool, + + /// Data in this chunk. + table_data: TestChunkData, + + /// A saved error that is returned instead of actual results + saved_error: Option, + + /// Order of this chunk relative to other overlapping chunks. + order: ChunkOrder, + + /// The sort key of this chunk + sort_key: Option, + + /// Suppress output + quiet: bool, +} + +/// Implements a method for adding a column with default stats +macro_rules! impl_with_column { + ($NAME:ident, $DATA_TYPE:ident) => { + pub fn $NAME(self, column_name: impl Into) -> Self { + let column_name = column_name.into(); + + let new_column_schema = SchemaBuilder::new() + .field(&column_name, DataType::$DATA_TYPE) + .unwrap() + .build() + .unwrap(); + self.add_schema_to_table(new_column_schema, None) + } + }; +} + +/// Implements a method for adding a column with stats that have the specified min and max +macro_rules! 
impl_with_column_with_stats { + ($NAME:ident, $DATA_TYPE:ident, $RUST_TYPE:ty, $STAT_TYPE:ident) => { + pub fn $NAME( + self, + column_name: impl Into, + min: Option<$RUST_TYPE>, + max: Option<$RUST_TYPE>, + ) -> Self { + let column_name = column_name.into(); + + let new_column_schema = SchemaBuilder::new() + .field(&column_name, DataType::$DATA_TYPE) + .unwrap() + .build() + .unwrap(); + + let stats = ColumnStatistics { + null_count: Precision::Absent, + max_value: option_to_precision(max.map(|s| ScalarValue::from(s))), + min_value: option_to_precision(min.map(|s| ScalarValue::from(s))), + distinct_count: Precision::Absent, + }; + + self.add_schema_to_table(new_column_schema, Some(stats)) + } + }; +} + +impl TestChunk { + pub fn new(table_name: impl Into) -> Self { + let table_name = table_name.into(); + Self { + table_name, + schema: SchemaBuilder::new().build().unwrap(), + column_stats: Default::default(), + num_rows: None, + id: ChunkId::new_test(0), + may_contain_pk_duplicates: Default::default(), + table_data: TestChunkData::RecordBatches(vec![]), + saved_error: Default::default(), + order: ChunkOrder::MIN, + sort_key: None, + partition_id: TransitionPartitionId::arbitrary_for_testing(), + quiet: false, + } + } + + fn push_record_batch(&mut self, batch: RecordBatch) { + match &mut self.table_data { + TestChunkData::RecordBatches(batches) => { + batches.push(batch); + } + TestChunkData::Parquet(_) => panic!("chunk is parquet-based"), + } + } + + pub fn with_order(self, order: i64) -> Self { + Self { + order: ChunkOrder::new(order), + ..self + } + } + + pub fn with_dummy_parquet_file(self) -> Self { + self.with_dummy_parquet_file_and_store("iox://store") + } + + pub fn with_dummy_parquet_file_and_size(self, size: usize) -> Self { + self.with_dummy_parquet_file_and_store_and_size("iox://store", size) + } + + pub fn with_dummy_parquet_file_and_store(self, store: &str) -> Self { + self.with_dummy_parquet_file_and_store_and_size(store, 1) + } + + pub fn 
with_dummy_parquet_file_and_store_and_size(self, store: &str, size: usize) -> Self { + match self.table_data { + TestChunkData::RecordBatches(batches) => { + assert!(batches.is_empty(), "chunk already has record batches"); + } + TestChunkData::Parquet(_) => panic!("chunk already has a file"), + } + + Self { + table_data: TestChunkData::Parquet(ParquetExecInput { + object_store_url: ObjectStoreUrl::parse(store).unwrap(), + object_meta: ObjectMeta { + location: Self::parquet_location(self.id), + last_modified: Default::default(), + size, + e_tag: None, + version: None, + }, + }), + ..self + } + } + + fn parquet_location(chunk_id: ChunkId) -> Path { + Path::parse(format!("{}.parquet", chunk_id.get().as_u128())).unwrap() + } + + /// Returns the receiver configured to suppress any output to STDOUT. + pub fn with_quiet(mut self) -> Self { + self.quiet = true; + self + } + + pub fn with_id(mut self, id: u128) -> Self { + self.id = ChunkId::new_test(id); + + if let TestChunkData::Parquet(parquet_input) = &mut self.table_data { + parquet_input.object_meta.location = Self::parquet_location(self.id); + } + + self + } + + pub fn with_partition(mut self, id: i64) -> Self { + self.partition_id = + TransitionPartitionId::new(TableId::new(id), &PartitionKey::from("arbitrary")); + self + } + + pub fn with_partition_id(mut self, id: TransitionPartitionId) -> Self { + self.partition_id = id; + self + } + + /// specify that any call should result in an error with the message + /// specified + pub fn with_error(mut self, error_message: impl Into) -> Self { + self.saved_error = Some(error_message.into()); + self + } + + /// Checks the saved error, and returns it if any, otherwise returns OK + fn check_error(&self) -> Result<(), DataFusionError> { + if let Some(message) = self.saved_error.as_ref() { + Err(DataFusionError::External(message.clone().into())) + } else { + Ok(()) + } + } + + /// Set the `may_contain_pk_duplicates` flag + pub fn with_may_contain_pk_duplicates(mut self, v: 
bool) -> Self { + self.may_contain_pk_duplicates = v; + self + } + + /// Register a tag column with the test chunk with default stats + pub fn with_tag_column(self, column_name: impl Into) -> Self { + let column_name = column_name.into(); + + // make a new schema with the specified column and + // merge it in to any existing schema + let new_column_schema = SchemaBuilder::new().tag(&column_name).build().unwrap(); + + self.add_schema_to_table(new_column_schema, None) + } + + /// Register a tag column with stats with the test chunk + pub fn with_tag_column_with_stats( + self, + column_name: impl Into, + min: Option<&str>, + max: Option<&str>, + ) -> Self { + self.with_tag_column_with_full_stats(column_name, min, max, 0, None) + } + + /// Register a tag column with stats with the test chunk + pub fn with_tag_column_with_full_stats( + self, + column_name: impl Into, + min: Option<&str>, + max: Option<&str>, + count: u64, + distinct_count: Option, + ) -> Self { + let null_count = 0; + self.with_tag_column_with_nulls_and_full_stats( + column_name, + min, + max, + count, + distinct_count, + null_count, + ) + } + + fn update_count(&mut self, count: usize) { + match self.num_rows { + Some(existing) => assert_eq!(existing, count), + None => self.num_rows = Some(count), + } + } + + /// Register a tag column with stats with the test chunk + pub fn with_tag_column_with_nulls_and_full_stats( + mut self, + column_name: impl Into, + min: Option<&str>, + max: Option<&str>, + count: u64, + distinct_count: Option, + null_count: u64, + ) -> Self { + let column_name = column_name.into(); + + // make a new schema with the specified column and + // merge it in to any existing schema + let new_column_schema = SchemaBuilder::new().tag(&column_name).build().unwrap(); + + // Construct stats + let stats = ColumnStatistics { + null_count: Precision::Exact(null_count as usize), + max_value: option_to_precision(max.map(ScalarValue::from)), + min_value: 
option_to_precision(min.map(ScalarValue::from)), + distinct_count: option_to_precision(distinct_count.map(|c| c.get() as usize)), + }; + + self.update_count(count as usize); + self.add_schema_to_table(new_column_schema, Some(stats)) + } + + /// Register a timestamp column with the test chunk with default stats + pub fn with_time_column(self) -> Self { + // make a new schema with the specified column and + // merge it in to any existing schema + let new_column_schema = SchemaBuilder::new().timestamp().build().unwrap(); + + self.add_schema_to_table(new_column_schema, None) + } + + /// Register a timestamp column with the test chunk + pub fn with_time_column_with_stats(self, min: Option, max: Option) -> Self { + self.with_time_column_with_full_stats(min, max, 0, None) + } + + /// Register a timestamp column with full stats with the test chunk + pub fn with_time_column_with_full_stats( + mut self, + min: Option, + max: Option, + count: u64, + distinct_count: Option, + ) -> Self { + // make a new schema with the specified column and + // merge it in to any existing schema + let new_column_schema = SchemaBuilder::new().timestamp().build().unwrap(); + let null_count = 0; + + // Construct stats + let stats = ColumnStatistics { + null_count: Precision::Exact(null_count as usize), + max_value: option_to_precision(max.map(timestamptz_nano)), + min_value: option_to_precision(min.map(timestamptz_nano)), + distinct_count: option_to_precision(distinct_count.map(|c| c.get() as usize)), + }; + + self.update_count(count as usize); + self.add_schema_to_table(new_column_schema, Some(stats)) + } + + pub fn with_timestamp_min_max(mut self, min: i64, max: i64) -> Self { + let stats = self + .column_stats + .get_mut(TIME_COLUMN_NAME) + .expect("stats in sync w/ columns"); + + stats.min_value = Precision::Exact(timestamptz_nano(min)); + stats.max_value = Precision::Exact(timestamptz_nano(max)); + + self + } + + impl_with_column!(with_i64_field_column, Int64); + 
impl_with_column_with_stats!(with_i64_field_column_with_stats, Int64, i64, I64); + + impl_with_column!(with_u64_column, UInt64); + impl_with_column_with_stats!(with_u64_field_column_with_stats, UInt64, u64, U64); + + impl_with_column!(with_f64_field_column, Float64); + impl_with_column_with_stats!(with_f64_field_column_with_stats, Float64, f64, F64); + + impl_with_column!(with_bool_field_column, Boolean); + impl_with_column_with_stats!(with_bool_field_column_with_stats, Boolean, bool, Bool); + + /// Register a string field column with the test chunk + pub fn with_string_field_column_with_stats( + self, + column_name: impl Into, + min: Option<&str>, + max: Option<&str>, + ) -> Self { + let column_name = column_name.into(); + + // make a new schema with the specified column and + // merge it in to any existing schema + let new_column_schema = SchemaBuilder::new() + .field(&column_name, DataType::Utf8) + .unwrap() + .build() + .unwrap(); + + // Construct stats + let stats = ColumnStatistics { + null_count: Precision::Absent, + max_value: option_to_precision(max.map(ScalarValue::from)), + min_value: option_to_precision(min.map(ScalarValue::from)), + distinct_count: Precision::Absent, + }; + + self.add_schema_to_table(new_column_schema, Some(stats)) + } + + /// Adds the specified schema and optionally a column summary containing optional stats. + /// If `add_column_summary` is false, `stats` is ignored. If `add_column_summary` is true but + /// `stats` is `None`, default stats will be added to the column summary. 
+ fn add_schema_to_table( + mut self, + new_column_schema: Schema, + input_stats: Option, + ) -> Self { + let mut merger = SchemaMerger::new(); + merger = merger.merge(&new_column_schema).unwrap(); + merger = merger.merge(&self.schema).expect("merging was successful"); + self.schema = merger.build(); + + for f in new_column_schema.inner().fields() { + self.column_stats.insert( + f.name().clone(), + input_stats.as_ref().cloned().unwrap_or_default(), + ); + } + + self + } + + /// Prepares this chunk to return a specific record batch with one + /// row of non null data. + /// tag: MA + pub fn with_one_row_of_data(mut self) -> Self { + // create arrays + let columns = self + .schema + .iter() + .map(|(_influxdb_column_type, field)| match field.data_type() { + DataType::Int64 => Arc::new(Int64Array::from(vec![1000])) as ArrayRef, + DataType::UInt64 => Arc::new(UInt64Array::from(vec![1000])) as ArrayRef, + DataType::Utf8 => Arc::new(StringArray::from(vec!["MA"])) as ArrayRef, + DataType::Timestamp(TimeUnit::Nanosecond, tz) => Arc::new( + TimestampNanosecondArray::from(vec![1000]).with_timezone_opt(tz.clone()), + ) as ArrayRef, + DataType::Dictionary(key, value) + if key.as_ref() == &DataType::Int32 && value.as_ref() == &DataType::Utf8 => + { + let dict: DictionaryArray = vec!["MA"].into_iter().collect(); + Arc::new(dict) as ArrayRef + } + DataType::Float64 => Arc::new(Float64Array::from(vec![99.5])) as ArrayRef, + DataType::Boolean => Arc::new(BooleanArray::from(vec![true])) as ArrayRef, + _ => unimplemented!( + "Unimplemented data type for test database: {:?}", + field.data_type() + ), + }) + .collect::>(); + + let batch = + RecordBatch::try_new(self.schema.as_arrow(), columns).expect("made record batch"); + if !self.quiet { + println!("TestChunk batch data: {batch:#?}"); + } + + self.push_record_batch(batch); + self + } + + /// Prepares this chunk to return a specific record batch with a single tag, field and timestamp like + pub fn with_one_row_of_specific_data( + mut 
self, + tag_val: impl AsRef, + field_val: i64, + ts_val: i64, + ) -> Self { + // create arrays + let columns = self + .schema + .iter() + .map(|(_influxdb_column_type, field)| match field.data_type() { + DataType::Int64 => Arc::new(Int64Array::from(vec![field_val])) as ArrayRef, + DataType::Timestamp(TimeUnit::Nanosecond, tz) => Arc::new( + TimestampNanosecondArray::from(vec![ts_val]).with_timezone_opt(tz.clone()), + ) as ArrayRef, + DataType::Dictionary(key, value) + if key.as_ref() == &DataType::Int32 && value.as_ref() == &DataType::Utf8 => + { + let dict: DictionaryArray = + vec![tag_val.as_ref()].into_iter().collect(); + Arc::new(dict) as ArrayRef + } + _ => unimplemented!( + "Unimplemented data type for test database: {:?}", + field.data_type() + ), + }) + .collect::>(); + + let batch = + RecordBatch::try_new(self.schema.as_arrow(), columns).expect("made record batch"); + if !self.quiet { + println!("TestChunk batch data: {batch:#?}"); + } + + self.push_record_batch(batch); + self + } + + /// Prepares this chunk to return a specific record batch with three + /// rows of non null data that look like, no duplicates within + /// "+------+------+-----------+-------------------------------+", + /// "| tag1 | tag2 | field_int | time |", + /// "+------+------+-----------+-------------------------------+", + /// "| WA | SC | 1000 | 1970-01-01 00:00:00.000008 |", + /// "| VT | NC | 10 | 1970-01-01 00:00:00.000010 |", + /// "| UT | RI | 70 | 1970-01-01 00:00:00.000020 |", + /// "+------+------+-----------+-------------------------------+", + /// Stats(min, max) : tag1(UT, WA), tag2(RI, SC), time(8000, 20000) + pub fn with_three_rows_of_data(mut self) -> Self { + // create arrays + let columns = self + .schema + .iter() + .map(|(_influxdb_column_type, field)| match field.data_type() { + DataType::Int64 => Arc::new(Int64Array::from(vec![1000, 10, 70])) as ArrayRef, + DataType::UInt64 => Arc::new(UInt64Array::from(vec![1000, 10, 70])) as ArrayRef, + DataType::Utf8 => match 
field.name().as_str() { + "tag1" => Arc::new(StringArray::from(vec!["WA", "VT", "UT"])) as ArrayRef, + "tag2" => Arc::new(StringArray::from(vec!["SC", "NC", "RI"])) as ArrayRef, + _ => Arc::new(StringArray::from(vec!["TX", "PR", "OR"])) as ArrayRef, + }, + DataType::Timestamp(TimeUnit::Nanosecond, tz) => Arc::new( + TimestampNanosecondArray::from(vec![8000, 10000, 20000]) + .with_timezone_opt(tz.clone()), + ) as ArrayRef, + DataType::Dictionary(key, value) + if key.as_ref() == &DataType::Int32 && value.as_ref() == &DataType::Utf8 => + { + match field.name().as_str() { + "tag1" => Arc::new( + vec!["WA", "VT", "UT"] + .into_iter() + .collect::>(), + ) as ArrayRef, + "tag2" => Arc::new( + vec!["SC", "NC", "RI"] + .into_iter() + .collect::>(), + ) as ArrayRef, + _ => Arc::new( + vec!["TX", "PR", "OR"] + .into_iter() + .collect::>(), + ) as ArrayRef, + } + } + _ => unimplemented!( + "Unimplemented data type for test database: {:?}", + field.data_type() + ), + }) + .collect::>(); + + let batch = + RecordBatch::try_new(self.schema.as_arrow(), columns).expect("made record batch"); + + self.push_record_batch(batch); + self + } + + /// Prepares this chunk to return a specific record batch with four + /// rows of non null data that look like, duplicates within + /// "+------+------+-----------+-------------------------------+", + /// "| tag1 | tag2 | field_int | time |", + /// "+------+------+-----------+-------------------------------+", + /// "| WA | SC | 1000 | 1970-01-01 00:00:00.000028 |", + /// "| VT | NC | 10 | 1970-01-01 00:00:00.000210 |", (1) + /// "| UT | RI | 70 | 1970-01-01 00:00:00.000220 |", + /// "| VT | NC | 50 | 1970-01-01 00:00:00.000210 |", // duplicate of (1) + /// "+------+------+-----------+-------------------------------+", + /// Stats(min, max) : tag1(UT, WA), tag2(RI, SC), time(28000, 220000) + pub fn with_four_rows_of_data(mut self) -> Self { + // create arrays + let columns = self + .schema + .iter() + .map(|(_influxdb_column_type, field)| match 
field.data_type() { + DataType::Int64 => Arc::new(Int64Array::from(vec![1000, 10, 70, 50])) as ArrayRef, + DataType::Utf8 => match field.name().as_str() { + "tag1" => Arc::new(StringArray::from(vec!["WA", "VT", "UT", "VT"])) as ArrayRef, + "tag2" => Arc::new(StringArray::from(vec!["SC", "NC", "RI", "NC"])) as ArrayRef, + _ => Arc::new(StringArray::from(vec!["TX", "PR", "OR", "AL"])) as ArrayRef, + }, + DataType::Timestamp(TimeUnit::Nanosecond, tz) => Arc::new( + TimestampNanosecondArray::from(vec![28000, 210000, 220000, 210000]) + .with_timezone_opt(tz.clone()), + ) as ArrayRef, + DataType::Dictionary(key, value) + if key.as_ref() == &DataType::Int32 && value.as_ref() == &DataType::Utf8 => + { + match field.name().as_str() { + "tag1" => Arc::new( + vec!["WA", "VT", "UT", "VT"] + .into_iter() + .collect::>(), + ) as ArrayRef, + "tag2" => Arc::new( + vec!["SC", "NC", "RI", "NC"] + .into_iter() + .collect::>(), + ) as ArrayRef, + _ => Arc::new( + vec!["TX", "PR", "OR", "AL"] + .into_iter() + .collect::>(), + ) as ArrayRef, + } + } + _ => unimplemented!( + "Unimplemented data type for test database: {:?}", + field.data_type() + ), + }) + .collect::>(); + + let batch = + RecordBatch::try_new(self.schema.as_arrow(), columns).expect("made record batch"); + + self.push_record_batch(batch); + self + } + + /// Prepares this chunk to return a specific record batch with five + /// rows of non null data that look like, no duplicates within + /// "+------+------+-----------+-------------------------------+", + /// "| tag1 | tag2 | field_int | time |", + /// "+------+------+-----------+-------------------------------+", + /// "| MT | CT | 1000 | 1970-01-01 00:00:00.000001 |", + /// "| MT | AL | 10 | 1970-01-01 00:00:00.000007 |", + /// "| CT | CT | 70 | 1970-01-01 00:00:00.000000100 |", + /// "| AL | MA | 100 | 1970-01-01 00:00:00.000000050 |", + /// "| MT | AL | 5 | 1970-01-01 00:00:00.000005 |", + /// "+------+------+-----------+-------------------------------+", + /// 
Stats(min, max) : tag1(AL, MT), tag2(AL, MA), time(5, 7000) + pub fn with_five_rows_of_data(mut self) -> Self { + // create arrays + let columns = self + .schema + .iter() + .map(|(_influxdb_column_type, field)| match field.data_type() { + DataType::Int64 => { + Arc::new(Int64Array::from(vec![1000, 10, 70, 100, 5])) as ArrayRef + } + DataType::Utf8 => { + match field.name().as_str() { + "tag1" => Arc::new(StringArray::from(vec!["MT", "MT", "CT", "AL", "MT"])) + as ArrayRef, + "tag2" => Arc::new(StringArray::from(vec!["CT", "AL", "CT", "MA", "AL"])) + as ArrayRef, + _ => Arc::new(StringArray::from(vec!["CT", "MT", "AL", "AL", "MT"])) + as ArrayRef, + } + } + DataType::Timestamp(TimeUnit::Nanosecond, tz) => Arc::new( + TimestampNanosecondArray::from(vec![1000, 7000, 100, 50, 5000]) + .with_timezone_opt(tz.clone()), + ) as ArrayRef, + DataType::Dictionary(key, value) + if key.as_ref() == &DataType::Int32 && value.as_ref() == &DataType::Utf8 => + { + match field.name().as_str() { + "tag1" => Arc::new( + vec!["MT", "MT", "CT", "AL", "MT"] + .into_iter() + .collect::>(), + ) as ArrayRef, + "tag2" => Arc::new( + vec!["CT", "AL", "CT", "MA", "AL"] + .into_iter() + .collect::>(), + ) as ArrayRef, + _ => Arc::new( + vec!["CT", "MT", "AL", "AL", "MT"] + .into_iter() + .collect::>(), + ) as ArrayRef, + } + } + _ => unimplemented!( + "Unimplemented data type for test database: {:?}", + field.data_type() + ), + }) + .collect::>(); + + let batch = + RecordBatch::try_new(self.schema.as_arrow(), columns).expect("made record batch"); + + self.push_record_batch(batch); + self + } + + /// Prepares this chunk to return a specific record batch with ten + /// rows of non null data that look like, duplicates within + /// "+------+------+-----------+-------------------------------+", + /// "| tag1 | tag2 | field_int | time |", + /// "+------+------+-----------+-------------------------------+", + /// "| MT | CT | 1000 | 1970-01-01 00:00:00.000001 |", + /// "| MT | AL | 10 | 1970-01-01 
00:00:00.000007 |", (1) + /// "| CT | CT | 70 | 1970-01-01 00:00:00.000000100 |", + /// "| AL | MA | 100 | 1970-01-01 00:00:00.000000050 |", (2) + /// "| MT | AL | 5 | 1970-01-01 00:00:00.000005 |", (3) + /// "| MT | CT | 1000 | 1970-01-01 00:00:00.000002 |", + /// "| MT | AL | 20 | 1970-01-01 00:00:00.000007 |", // Duplicate with (1) + /// "| CT | CT | 70 | 1970-01-01 00:00:00.000000500 |", + /// "| AL | MA | 10 | 1970-01-01 00:00:00.000000050 |", // Duplicate with (2) + /// "| MT | AL | 30 | 1970-01-01 00:00:00.000005 |", // Duplicate with (3) + /// "+------+------+-----------+-------------------------------+", + /// Stats(min, max) : tag1(AL, MT), tag2(AL, MA), time(5, 7000) + pub fn with_ten_rows_of_data_some_duplicates(mut self) -> Self { + // create arrays + let columns = self + .schema + .iter() + .map(|(_influxdb_column_type, field)| match field.data_type() { + DataType::Int64 => Arc::new(Int64Array::from(vec![ + 1000, 10, 70, 100, 5, 1000, 20, 70, 10, 30, + ])) as ArrayRef, + DataType::Utf8 => match field.name().as_str() { + "tag1" => Arc::new(StringArray::from(vec![ + "MT", "MT", "CT", "AL", "MT", "MT", "MT", "CT", "AL", "MT", + ])) as ArrayRef, + "tag2" => Arc::new(StringArray::from(vec![ + "CT", "AL", "CT", "MA", "AL", "CT", "AL", "CT", "MA", "AL", + ])) as ArrayRef, + _ => Arc::new(StringArray::from(vec![ + "CT", "MT", "AL", "AL", "MT", "CT", "MT", "AL", "AL", "MT", + ])) as ArrayRef, + }, + DataType::Timestamp(TimeUnit::Nanosecond, tz) => Arc::new( + TimestampNanosecondArray::from(vec![ + 1000, 7000, 100, 50, 5, 2000, 7000, 500, 50, 5, + ]) + .with_timezone_opt(tz.clone()), + ) as ArrayRef, + DataType::Dictionary(key, value) + if key.as_ref() == &DataType::Int32 && value.as_ref() == &DataType::Utf8 => + { + match field.name().as_str() { + "tag1" => Arc::new( + vec!["MT", "MT", "CT", "AL", "MT", "MT", "MT", "CT", "AL", "MT"] + .into_iter() + .collect::>(), + ) as ArrayRef, + "tag2" => Arc::new( + vec!["CT", "AL", "CT", "MA", "AL", "CT", "AL", "CT", 
"MA", "AL"] + .into_iter() + .collect::>(), + ) as ArrayRef, + _ => Arc::new( + vec!["CT", "MT", "AL", "AL", "MT", "CT", "MT", "AL", "AL", "MT"] + .into_iter() + .collect::>(), + ) as ArrayRef, + } + } + _ => unimplemented!( + "Unimplemented data type for test database: {:?}", + field.data_type() + ), + }) + .collect::>(); + + let batch = + RecordBatch::try_new(self.schema.as_arrow(), columns).expect("made record batch"); + + self.push_record_batch(batch); + self + } + + /// Set the sort key for this chunk + pub fn with_sort_key(self, sort_key: SortKey) -> Self { + Self { + sort_key: Some(sort_key), + ..self + } + } + + pub fn table_name(&self) -> &str { + &self.table_name + } +} + +impl fmt::Display for TestChunk { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.table_name()) + } +} + +impl QueryChunk for TestChunk { + fn stats(&self) -> Arc { + self.check_error().unwrap(); + + Arc::new(DataFusionStatistics { + num_rows: option_to_precision(self.num_rows), + total_byte_size: Precision::Absent, + column_statistics: self + .schema + .inner() + .fields() + .iter() + .map(|f| self.column_stats.get(f.name()).cloned().unwrap_or_default()) + .collect(), + }) + } + + fn schema(&self) -> &Schema { + &self.schema + } + + fn partition_id(&self) -> &TransitionPartitionId { + &self.partition_id + } + + fn sort_key(&self) -> Option<&SortKey> { + self.sort_key.as_ref() + } + + fn id(&self) -> ChunkId { + self.id + } + + fn may_contain_pk_duplicates(&self) -> bool { + self.may_contain_pk_duplicates + } + + fn data(&self) -> QueryChunkData { + self.check_error().unwrap(); + + match &self.table_data { + TestChunkData::RecordBatches(batches) => { + QueryChunkData::in_mem(batches.clone(), Arc::clone(self.schema.inner())) + } + TestChunkData::Parquet(input) => QueryChunkData::Parquet(input.clone()), + } + } + + fn chunk_type(&self) -> &str { + "Test Chunk" + } + + fn order(&self) -> ChunkOrder { + self.order + } + + fn as_any(&self) -> &dyn Any { + 
self + } +} + +/// Return the raw data from the list of chunks +pub async fn raw_data(chunks: &[Arc]) -> Vec { + let ctx = IOxSessionContext::with_testing(); + let mut batches = vec![]; + for c in chunks { + batches.append(&mut c.data().read_to_batches(c.schema(), ctx.inner()).await); + } + batches +} + +pub fn format_logical_plan(plan: &LogicalPlan) -> Vec { + format_lines(&plan.display_indent().to_string()) +} + +pub fn format_execution_plan(plan: &Arc) -> Vec { + format_lines(&displayable(plan.as_ref()).indent(false).to_string()) +} + +fn format_lines(s: &str) -> Vec { + s.trim() + .split('\n') + .map(|s| { + // Always add a leading space to ensure tha all lines in the YAML insta snapshots are quoted, otherwise the + // alignment gets messed up and the snapshot would be hard to read. + format!(" {s}") + }) + .collect() +} diff --git a/iox_query/src/util.rs b/iox_query/src/util.rs new file mode 100644 index 0000000..7cd92a4 --- /dev/null +++ b/iox_query/src/util.rs @@ -0,0 +1,325 @@ +//! 
This module contains DataFusion utility functions and helpers + +use std::{ + cmp::{max, min}, + sync::Arc, +}; + +use arrow::{ + array::TimestampNanosecondArray, + compute::SortOptions, + datatypes::{Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}, + record_batch::RecordBatch, +}; + +use data_types::TimestampMinMax; +use datafusion::common::stats::Precision; +use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries}; +use datafusion::{ + self, + common::ToDFSchema, + datasource::{provider_as_source, MemTable}, + error::DataFusionError, + execution::context::ExecutionProps, + logical_expr::{interval_arithmetic::Interval, LogicalPlan, LogicalPlanBuilder}, + optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}, + physical_expr::create_physical_expr, + physical_plan::{ + expressions::{col as physical_col, PhysicalSortExpr}, + PhysicalExpr, + }, + prelude::{Column, Expr}, +}; + +use itertools::Itertools; +use observability_deps::tracing::trace; +use schema::{sort::SortKey, TIME_COLUMN_NAME}; +use snafu::{ensure, OptionExt, ResultExt, Snafu}; + +#[derive(Debug, Snafu)] +#[allow(missing_copy_implementations, missing_docs)] +pub enum Error { + #[snafu(display("The Record batch is empty"))] + EmptyBatch, + + #[snafu(display("Error while searching Time column in a Record Batch"))] + TimeColumn { source: arrow::error::ArrowError }, + + #[snafu(display("Error while casting Timenanosecond on Time column"))] + TimeCasting, + + #[snafu(display("Time column does not have value"))] + TimeValue, + + #[snafu(display("Time column is null"))] + TimeNull, +} + +/// A specialized `Error` +pub type Result = std::result::Result; + +/// Create a logical plan that produces the record batch +pub fn make_scan_plan(batch: RecordBatch) -> std::result::Result { + let schema = batch.schema(); + let partitions = vec![vec![batch]]; + let projection = None; // scan all columns + + let table = MemTable::try_new(schema, partitions)?; + + let source = 
provider_as_source(Arc::new(table)); + + LogicalPlanBuilder::scan("memtable", source, projection)?.build() +} + +pub fn logical_sort_key_exprs(sort_key: &SortKey) -> Vec { + sort_key + .iter() + .map(|(key, options)| { + let expr = Expr::Column(Column::from_name(key.as_ref())); + expr.sort(!options.descending, options.nulls_first) + }) + .collect() +} + +pub fn arrow_sort_key_exprs( + sort_key: &SortKey, + input_schema: &ArrowSchema, +) -> Vec { + sort_key + .iter() + .flat_map(|(key, options)| { + // Skip over missing columns + let expr = physical_col(key, input_schema).ok()?; + Some(PhysicalSortExpr { + expr, + options: SortOptions { + descending: options.descending, + nulls_first: options.nulls_first, + }, + }) + }) + .collect() +} + +/// Build a datafusion physical expression from a logical one +pub fn df_physical_expr( + schema: ArrowSchemaRef, + expr: Expr, +) -> std::result::Result, DataFusionError> { + let df_schema = Arc::clone(&schema).to_dfschema_ref()?; + + let props = ExecutionProps::new(); + let simplifier = + ExprSimplifier::new(SimplifyContext::new(&props).with_schema(Arc::clone(&df_schema))); + + // apply type coercion here to ensure types match + trace!(%df_schema, "input schema"); + let expr = simplifier.coerce(expr, Arc::clone(&df_schema))?; + trace!(%expr, "coerced logical expression"); + + create_physical_expr(&expr, df_schema.as_ref(), schema.as_ref(), &props) +} + +/// Return min and max for column `time` of the given set of record batches by +/// performing an `O(n)` scan of all provided batches. 
+pub fn compute_timenanosecond_min_max<'a, I>(batches: I) -> Result<TimestampMinMax>
+where
+    I: IntoIterator<Item = &'a RecordBatch>,
+{
+    let mut min_time = i64::MAX;
+    let mut max_time = i64::MIN;
+    for batch in batches {
+        let (mi, ma) = compute_timenanosecond_min_max_for_one_record_batch(batch)?;
+        min_time = min(min_time, mi);
+        max_time = max(max_time, ma);
+    }
+    // NOTE(review): an empty iterator yields min=i64::MAX > max=i64::MIN (an
+    // inverted range); callers presumably guarantee at least one batch — confirm.
+    Ok(TimestampMinMax {
+        min: min_time,
+        max: max_time,
+    })
+}
+
+/// Return min and max for column `time` in the given record batch by performing
+/// an `O(n)` scan of `batch`.
+pub fn compute_timenanosecond_min_max_for_one_record_batch(
+    batch: &RecordBatch,
+) -> Result<(i64, i64)> {
+    // A batch with no columns cannot contain a time column.
+    ensure!(batch.num_columns() > 0, EmptyBatchSnafu);
+
+    // Locate the time column by name in the batch's schema.
+    let index = batch
+        .schema()
+        .index_of(TIME_COLUMN_NAME)
+        .context(TimeColumnSnafu {})?;
+
+    let time_col = batch
+        .column(index)
+        .as_any()
+        .downcast_ref::<TimestampNanosecondArray>()
+        .context(TimeCastingSnafu {})?;
+
+    // Single scan; `minmax` distinguishes the empty / one-element / multi-element
+    // cases, and any null timestamp is reported as an error.
+    let (min, max) = match time_col.iter().minmax() {
+        itertools::MinMaxResult::NoElements => return Err(Error::TimeValue),
+        itertools::MinMaxResult::OneElement(val) => {
+            let val = val.context(TimeNullSnafu)?;
+            (val, val)
+        }
+        itertools::MinMaxResult::MinMax(min, max) => {
+            (min.context(TimeNullSnafu)?, max.context(TimeNullSnafu)?)
+        }
+    };
+
+    Ok((min, max))
+}
+
+/// Determine the possible maximum range for each of the fields in a
+/// [`ArrowSchema`] once the [`Expr`] has been applied. The returned
+/// Vec includes an Interval for every field in the schema in the same
+/// order. Any fields that are not constrained by the expression will
+/// have an unbounded interval.
+pub fn calculate_field_intervals( + schema: ArrowSchemaRef, + expr: Expr, +) -> Result, DataFusionError> { + // make unknown boundaries for each column + // TODO use upstream code when https://github.com/apache/arrow-datafusion/pull/8377 is merged + let fields = schema.fields(); + let boundaries = fields + .iter() + .enumerate() + .map(|(i, field)| { + let column = datafusion::physical_expr::expressions::Column::new(field.name(), i); + let interval = Interval::make_unbounded(field.data_type())?; + Ok(ExprBoundaries { + column, + interval, + distinct_count: Precision::Absent, + }) + }) + .collect::, DataFusionError>>()?; + + let context = AnalysisContext::new(boundaries); + let analysis_result = analyze( + &df_physical_expr(Arc::clone(&schema), expr)?, + context, + &schema, + )?; + + let intervals = analysis_result + .boundaries + .into_iter() + .map(|b| b.interval) + .collect::>(); + + Ok(intervals) +} + +/// Determine the possible maximum range for the named field in the +/// ['ArrowSchema'] once the ['Expr'] has been applied. 
+pub fn calculate_field_interval( + schema: ArrowSchemaRef, + expr: Expr, + name: &str, +) -> Result { + let idx = schema.index_of(name)?; + let mut intervals = calculate_field_intervals(Arc::clone(&schema), expr)?; + Ok(intervals.swap_remove(idx)) +} + +#[cfg(test)] +mod tests { + use datafusion::common::rounding::next_down; + use datafusion::common::ScalarValue; + use datafusion::logical_expr::{col, lit}; + use schema::{builder::SchemaBuilder, InfluxFieldType, TIME_DATA_TIMEZONE}; + + use super::*; + + fn time_interval(lower: Option, upper: Option) -> Interval { + let lower = ScalarValue::TimestampNanosecond(lower, TIME_DATA_TIMEZONE()); + let upper = ScalarValue::TimestampNanosecond(upper, TIME_DATA_TIMEZONE()); + Interval::try_new(lower, upper).unwrap() + } + + fn f64_interval(lower: Option, upper: Option) -> Interval { + let lower = ScalarValue::Float64(lower); + let upper = ScalarValue::Float64(upper); + Interval::try_new(lower, upper).unwrap() + } + + #[test] + fn test_calculate_field_intervals() { + let schema = SchemaBuilder::new() + .timestamp() + .influx_field("a", InfluxFieldType::Float) + .build() + .unwrap() + .as_arrow(); + let expr = col("time") + .gt_eq(lit("2020-01-01T00:00:00Z")) + .and(col("time").lt(lit("2020-01-02T00:00:00Z"))) + .and(col("a").gt_eq(lit(1000000.0))) + .and(col("a").lt(lit(2000000.0))); + let intervals = calculate_field_intervals(schema, expr).unwrap(); + // 2020-01-01T00:00:00Z == 1577836800000000000 + // 2020-01-02T00:00:00Z == 1577923200000000000 + assert_eq!( + vec![ + time_interval(Some(1577836800000000000), Some(1577923200000000000i64 - 1),), + f64_interval(Some(1000000.0), Some(next_down(2000000.0))) + ], + intervals + ); + } + + #[test] + fn test_calculate_field_intervals_no_constraints() { + let schema = SchemaBuilder::new() + .timestamp() + .influx_field("a", InfluxFieldType::Float) + .build() + .unwrap() + .as_arrow(); + // must be a predicate (boolean expression) + let expr = lit("test").eq(lit("foo")); + let 
intervals = calculate_field_intervals(schema, expr).unwrap(); + assert_eq!( + vec![time_interval(None, None), f64_interval(None, None)], + intervals + ); + } + + #[test] + fn test_calculate_field_interval() { + let schema = SchemaBuilder::new() + .timestamp() + .influx_field("a", InfluxFieldType::Float) + .build() + .unwrap() + .as_arrow(); + let expr = col("time") + .gt_eq(lit("2020-01-01T00:00:00Z")) + .and(col("time").lt(lit("2020-01-02T00:00:00Z"))) + .and(col("a").gt_eq(lit(1000000.0))) + .and(col("a").lt(lit(2000000.0))); + + // Note + // 2020-01-01T00:00:00Z == 1577836800000000000 + // 2020-01-02T00:00:00Z == 1577923200000000000 + let interval = calculate_field_interval(Arc::clone(&schema), expr.clone(), "time").unwrap(); + assert_eq!( + time_interval(Some(1577836800000000000), Some(1577923200000000000 - 1),), + interval + ); + + let interval = calculate_field_interval(Arc::clone(&schema), expr.clone(), "a").unwrap(); + assert_eq!( + f64_interval(Some(1000000.0), Some(next_down(2000000.0))), + interval + ); + + assert_eq!( + "Arrow error: Schema error: Unable to get field named \"b\". 
Valid fields: [\"time\", \"a\"]", + calculate_field_interval(Arc::clone(&schema), expr.clone(), "b").unwrap_err().to_string(), + ); + } +} diff --git a/iox_query_influxql/Cargo.toml b/iox_query_influxql/Cargo.toml new file mode 100644 index 0000000..0c11612 --- /dev/null +++ b/iox_query_influxql/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "iox_query_influxql" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[lints] +workspace = true + +[dependencies] +arrow = { workspace = true } +chrono-tz = { version = "0.8" } +datafusion = { workspace = true } +datafusion_util = { path = "../datafusion_util" } +generated_types = { path = "../generated_types" } +influxdb_influxql_parser = { path = "../influxdb_influxql_parser" } +iox_query = { path = "../iox_query" } +itertools = "0.12.0" +observability_deps = { path = "../observability_deps" } +once_cell = "1" +predicate = { path = "../predicate" } +query_functions = { path = "../query_functions" } +regex = "1" +schema = { path = "../schema" } +serde_json = "1.0.111" +thiserror = "1.0" +workspace-hack = { version = "0.1", path = "../workspace-hack" } + +[dev-dependencies] # In alphabetical order +chrono = { version = "0.4", default-features = false } +test_helpers = { path = "../test_helpers" } +assert_matches = "1" +insta = { version = "1", features = ["yaml"] } diff --git a/iox_query_influxql/src/aggregate.rs b/iox_query_influxql/src/aggregate.rs new file mode 100644 index 0000000..badb279 --- /dev/null +++ b/iox_query_influxql/src/aggregate.rs @@ -0,0 +1,24 @@ +//! User defined aggregate functions implementing influxQL features. + +use datafusion::logical_expr::{ + AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, StateTypeFunction, +}; +use once_cell::sync::Lazy; +use std::sync::Arc; + +mod percentile; + +/// Definition of the `PERCENTILE` user-defined aggregate function. 
+pub(crate) static PERCENTILE: Lazy<Arc<AggregateUDF>> = Lazy::new(|| {
+    let return_type: ReturnTypeFunction = Arc::new(percentile::return_type);
+    let accumulator: AccumulatorFactoryFunction = Arc::new(percentile::accumulator);
+    let state_type: StateTypeFunction = Arc::new(percentile::state_type);
+
+    Arc::new(AggregateUDF::new(
+        percentile::NAME,
+        &percentile::SIGNATURE,
+        &return_type,
+        &accumulator,
+        &state_type,
+    ))
+});
diff --git a/iox_query_influxql/src/aggregate/percentile.rs b/iox_query_influxql/src/aggregate/percentile.rs
new file mode 100644
index 0000000..dda8659
--- /dev/null
+++ b/iox_query_influxql/src/aggregate/percentile.rs
+use crate::error;
+use arrow::array::{as_list_array, Array, ArrayRef, Float64Array, Int64Array};
+use arrow::datatypes::{DataType, Field};
+use datafusion::common::{downcast_value, DataFusionError, Result, ScalarValue};
+use datafusion::logical_expr::{Accumulator, Signature, TypeSignature, Volatility};
+use once_cell::sync::Lazy;
+use std::sync::Arc;
+
+/// The name of the percentile aggregate function.
+pub(super) const NAME: &str = "percentile";
+
+/// Valid signatures for the percentile aggregate function:
+/// every supported numeric input type paired with an Int64 or Float64
+/// percentile (`n`) argument.
+pub(super) static SIGNATURE: Lazy<Signature> = Lazy::new(|| {
+    Signature::one_of(
+        crate::NUMERICS
+            .iter()
+            .flat_map(|dt| {
+                [
+                    TypeSignature::Exact(vec![dt.clone(), DataType::Int64]),
+                    TypeSignature::Exact(vec![dt.clone(), DataType::Float64]),
+                ]
+            })
+            .collect(),
+        Volatility::Immutable,
+    )
+});
+
+/// Calculate the return type given the function signature. Percentile
+/// always returns the same type as the input column.
+pub(super) fn return_type(signature: &[DataType]) -> Result<Arc<DataType>> {
+    Ok(Arc::new(signature[0].clone()))
+}
+
+/// Create a new accumulator for the data type.
+pub(super) fn accumulator(dt: &DataType) -> Result<Box<dyn Accumulator>> {
+    Ok(Box::new(PercentileAccumulator::new(dt.clone())))
+}
+
+/// Calculate the intermediate merge state for the aggregator.
+pub(super) fn state_type(dt: &DataType) -> Result<Arc<Vec<DataType>>> {
+    Ok(Arc::new(vec![
+        // all input values buffered so far, carried between partial aggregations
+        DataType::List(Arc::new(Field::new("item", dt.clone(), true))),
+        // the requested percentile (the `n` argument)
+        DataType::Float64,
+    ]))
+}
+
+#[derive(Debug)]
+struct PercentileAccumulator {
+    // data type of the aggregated column; all buffered values must match it
+    data_type: DataType,
+    // every non-null input value seen so far (sorted lazily at evaluate time)
+    data: Vec<ScalarValue>,
+    // the percentile argument; captured from the first non-null `n` value
+    percentile: Option<f64>,
+}
+
+impl PercentileAccumulator {
+    fn new(data_type: DataType) -> Self {
+        Self {
+            data_type,
+            data: vec![],
+            percentile: None,
+        }
+    }
+
+    /// Buffer all non-null values of `array` into `self.data`.
+    fn update(&mut self, array: ArrayRef) -> Result<()> {
+        let array = Arc::clone(&array);
+        assert_eq!(array.data_type(), &self.data_type);
+
+        // Reserve exactly the number of non-null entries we are about to push.
+        let nulls = array.nulls();
+        let null_len = nulls.map_or(0, |nb| nb.null_count());
+        self.data.reserve(array.len() - null_len);
+        for idx in 0..array.len() {
+            if nulls.map_or(true, |nb| nb.is_valid(idx)) {
+                self.data.push(ScalarValue::try_from_array(&array, idx)?)
+            }
+        }
+        Ok(())
+    }
+
+    /// Record the percentile (`n`) argument the first time a non-null value is seen.
+    fn set_percentile(&mut self, array: ArrayRef) -> Result<()> {
+        if self.percentile.is_none() && array.is_valid(0) {
+            self.percentile = match array.data_type() {
+                DataType::Int64 => Some(downcast_value!(array, Int64Array).value(0) as f64),
+                DataType::Float64 => Some(downcast_value!(array, Float64Array).value(0)),
+                dt => {
+                    return error::internal(format!(
+                        "invalid data type ({dt}) for PERCENTILE n argument"
+                    ))
+                }
+            };
+        }
+        Ok(())
+    }
+}
+
+impl Accumulator for PercentileAccumulator {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        // values[0] is the aggregated column, values[1] the percentile argument
+        assert_eq!(values.len(), 2);
+
+        self.set_percentile(Arc::clone(&values[1]))?;
+        self.update(Arc::clone(&values[0]))
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        // Null result when no percentile was given or no data point maps to it.
+        let idx = self
+            .percentile
+            .and_then(|n| percentile_idx(self.data.len(), n));
+        if idx.is_none() {
+            return Ok(ScalarValue::Float64(None));
+        }
+
+        // Sort lazily at evaluation time rather than maintaining sorted state.
+        let array = ScalarValue::iter_to_array(self.data.clone())?;
+        let indices = arrow::compute::sort_to_indices(&array, None, None)?;
+        let array_idx = indices.value(idx.unwrap());
+        ScalarValue::try_from_array(&array, array_idx as usize)
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of::<Option<f64>>()
+            + std::mem::size_of::<DataType>()
+            + ScalarValue::size_of_vec(&self.data)
+    }
+
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        // Serialize the buffered values plus the percentile for partial merges.
+        let arr = ScalarValue::new_list(&self.data, &self.data_type);
+        Ok(vec![
+            ScalarValue::List(arr),
+            ScalarValue::Float64(self.percentile),
+        ])
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        assert_eq!(states.len(), 2);
+
+        self.set_percentile(Arc::clone(&states[1]))?;
+
+        // states[0] is a list array; each element is one partial accumulator's data.
+        let array = Arc::clone(&states[0]);
+        let list_array = as_list_array(&array);
+        for idx in 0..list_array.len() {
+            self.update(list_array.value(idx))?;
+        }
+        Ok(())
+    }
+}
+
+/// Calculate the location in an ordered list of len items where the
+/// location of the item at the given percentile would be found.
+///
+/// This uses the same algorithm as the original influxdb implementation
+/// of percentile as can be found in
+/// <https://github.com/influxdata/influxdb>.
+fn percentile_idx(len: usize, percentile: f64) -> Option<usize> {
+    // Round-half-up to a 1-based rank, then convert to a 0-based index;
+    // out-of-range ranks (including negative) yield None.
+    match TryInto::<usize>::try_into(
+        (((len as f64) * percentile / 100.0 + 0.5).floor() as isize) - 1,
+    ) {
+        Ok(idx) if idx < len => Some(idx),
+        _ => None,
+    }
+}
diff --git a/iox_query_influxql/src/error.rs b/iox_query_influxql/src/error.rs
new file mode 100644
index 0000000..cc2dd6d
--- /dev/null
+++ b/iox_query_influxql/src/error.rs
+use datafusion::common::Result;
+
+/// An error that was the result of an invalid InfluxQL query.
+pub(crate) fn query<T>(s: impl Into<String>) -> Result<T> {
+    Err(map::query(s))
+}
+
+/// An unexpected error whilst planning that represents a bug in IOx.
+pub(crate) fn internal<T>(s: impl Into<String>) -> Result<T> {
+    Err(map::internal(s))
+}
+
+/// The specified `feature` is not implemented.
+pub(crate) fn not_implemented<T>(feature: impl Into<String>) -> Result<T> {
+    Err(map::not_implemented(feature))
+}
+
+/// Functions that return a `DataFusionError` rather than a `Result`,
+/// making them convenient to use with functions like `map_err`.
+pub(crate) mod map { + use datafusion::common::DataFusionError; + use influxdb_influxql_parser::time_range::ExprError; + use thiserror::Error; + + #[derive(Debug, Error)] + enum PlannerError { + /// An unexpected error that represents a bug in IOx. + /// + /// The message is prefixed with `InfluxQL internal error: `, + /// which may be used by clients to identify internal InfluxQL + /// errors. + #[error("InfluxQL internal error: {0}")] + Internal(String), + } + + /// An error that was the result of an invalid InfluxQL query. + pub(crate) fn query(s: impl Into) -> DataFusionError { + DataFusionError::Plan(s.into()) + } + + /// An unexpected error whilst planning that represents a bug in IOx. + pub(crate) fn internal(s: impl Into) -> DataFusionError { + DataFusionError::External(Box::new(PlannerError::Internal(s.into()))) + } + + /// The specified `feature` is not implemented. + pub(crate) fn not_implemented(feature: impl Into) -> DataFusionError { + DataFusionError::NotImplemented(feature.into()) + } + + /// Map an [`ExprError`] to a DataFusion error. 
+ pub(crate) fn expr_error(err: ExprError) -> DataFusionError { + match err { + ExprError::Expression(s) => query(s), + ExprError::Internal(s) => internal(s), + } + } + + #[cfg(test)] + mod test { + use crate::error::map::PlannerError; + + #[test] + fn test_planner_error_display() { + // The InfluxQL internal error: + assert!(PlannerError::Internal("****".to_owned()) + .to_string() + .starts_with("InfluxQL internal error: ")) + } + } +} diff --git a/iox_query_influxql/src/frontend/mod.rs b/iox_query_influxql/src/frontend/mod.rs new file mode 100644 index 0000000..5e48085 --- /dev/null +++ b/iox_query_influxql/src/frontend/mod.rs @@ -0,0 +1 @@ +pub mod planner; diff --git a/iox_query_influxql/src/frontend/planner.rs b/iox_query_influxql/src/frontend/planner.rs new file mode 100644 index 0000000..f8f6ff0 --- /dev/null +++ b/iox_query_influxql/src/frontend/planner.rs @@ -0,0 +1,420 @@ +use arrow::datatypes::SchemaRef; +use datafusion::common::ParamValues; +use datafusion::physical_expr::execution_props::ExecutionProps; +use influxdb_influxql_parser::show_field_keys::ShowFieldKeysStatement; +use influxdb_influxql_parser::show_measurements::ShowMeasurementsStatement; +use influxdb_influxql_parser::show_tag_keys::ShowTagKeysStatement; +use influxdb_influxql_parser::show_tag_values::ShowTagValuesStatement; +use std::any::Any; +use std::collections::{HashMap, HashSet}; +use std::fmt; +use std::fmt::Debug; +use std::ops::Deref; +use std::sync::Arc; + +use crate::plan::{parse_regex, InfluxQLToLogicalPlan, SchemaProvider}; +use datafusion::datasource::provider_as_source; +use datafusion::execution::context::{SessionState, TaskContext}; +use datafusion::logical_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource}; +use datafusion::physical_expr::PhysicalSortExpr; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, Partitioning, SendableRecordBatchStream, +}; +use datafusion::{ + error::{DataFusionError, Result}, + physical_plan::ExecutionPlan, +}; +use 
influxdb_influxql_parser::common::MeasurementName; +use influxdb_influxql_parser::parse_statements; +use influxdb_influxql_parser::statement::Statement; +use influxdb_influxql_parser::visit::{Visitable, Visitor}; +use iox_query::exec::IOxSessionContext; +use observability_deps::tracing::debug; +use schema::Schema; + +struct ContextSchemaProvider<'a> { + state: &'a SessionState, + tables: HashMap, Schema)>, +} + +impl<'a> SchemaProvider for ContextSchemaProvider<'a> { + fn get_table_provider(&self, name: &str) -> Result> { + self.tables + .get(name) + .map(|(t, _)| Arc::clone(t)) + .ok_or_else(|| DataFusionError::Plan(format!("measurement does not exist: {name}"))) + } + + fn get_function_meta(&self, name: &str) -> Option> { + self.state.scalar_functions().get(name).cloned() + } + + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.state.aggregate_functions().get(name).cloned() + } + + fn table_names(&self) -> Vec<&'_ str> { + self.tables.keys().map(|k| k.as_str()).collect::>() + } + + fn table_exists(&self, name: &str) -> bool { + self.tables.contains_key(name) + } + + fn table_schema(&self, name: &str) -> Option { + self.tables.get(name).map(|(_, s)| s.clone()) + } + + fn execution_props(&self) -> &ExecutionProps { + self.state.execution_props() + } +} + +/// A physical operator that overrides the `schema` API, +/// to return an amended version owned by `SchemaExec`. The +/// principal use case is to add additional metadata to the schema. 
+struct SchemaExec {
+    input: Arc<dyn ExecutionPlan>,
+    // the amended schema reported in place of `input.schema()`
+    schema: SchemaRef,
+}
+
+impl Debug for SchemaExec {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.fmt_as(DisplayFormatType::Default, f)
+    }
+}
+
+impl ExecutionPlan for SchemaExec {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        // The override: report the amended schema, not the input's.
+        Arc::clone(&self.schema)
+    }
+
+    fn output_partitioning(&self) -> Partitioning {
+        self.input.output_partitioning()
+    }
+
+    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+        self.input.output_ordering()
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        vec![Arc::clone(&self.input)]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        unimplemented!()
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        // Execution is delegated untouched; only the schema is amended.
+        self.input.execute(partition, context)
+    }
+
+    fn statistics(&self) -> Result<datafusion::physical_plan::Statistics> {
+        Ok(datafusion::physical_plan::Statistics::new_unknown(
+            &self.schema(),
+        ))
+    }
+}
+
+impl DisplayAs for SchemaExec {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                write!(f, "SchemaExec")
+            }
+        }
+    }
+}
+
+/// Create plans for running InfluxQL queries against databases
+#[derive(Debug, Default, Copy, Clone)]
+pub struct InfluxQLQueryPlanner {}
+
+impl InfluxQLQueryPlanner {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Plan an InfluxQL query against the catalogs registered with `ctx`, and return a
+    /// DataFusion physical execution plan that runs on the query executor.
+    pub async fn query(
+        &self,
+        query: &str,
+        params: impl Into<ParamValues> + Send,
+        ctx: &IOxSessionContext,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let ctx = ctx.child_ctx("InfluxQLQueryPlanner::query");
+        debug!(text=%query, "planning InfluxQL query");
+
+        let statement = self.query_to_statement(query)?;
+        let logical_plan = self.statement_to_plan(statement, &ctx).await?;
+        // add params to plan only when they're non-empty
+        let logical_plan = match params.into() {
+            ParamValues::List(v) if !v.is_empty() => logical_plan.with_param_values(v)?,
+            ParamValues::Map(m) if !m.is_empty() => logical_plan.with_param_values(m)?,
+            _ => logical_plan,
+        };
+        let input = ctx.create_physical_plan(&logical_plan).await?;
+
+        // Merge schema-level metadata from the logical plan with the
+        // schema from the physical plan, as it is not propagated through the
+        // physical planning process.
+        let input_schema = input.schema();
+        let mut md = input_schema.metadata().clone();
+        md.extend(logical_plan.schema().metadata().clone());
+        let schema = Arc::new(arrow::datatypes::Schema::new_with_metadata(
+            input_schema.fields().clone(),
+            md,
+        ));
+
+        Ok(Arc::new(SchemaExec { input, schema }))
+    }
+
+    async fn statement_to_plan(
+        &self,
+        statement: Statement,
+        ctx: &IOxSessionContext,
+    ) -> Result<LogicalPlan> {
+        use std::collections::hash_map::Entry;
+
+        let ctx = ctx.child_ctx("statement_to_plan");
+        let session_cfg = ctx.inner().copied_config();
+        let cfg = session_cfg.options();
+        // Resolve the default catalog/schema pair the query will run against.
+        let schema = ctx
+            .inner()
+            .catalog(&cfg.catalog.default_catalog)
+            .ok_or_else(|| {
+                DataFusionError::Plan(format!(
+                    "failed to resolve catalog: {}",
+                    cfg.catalog.default_catalog
+                ))
+            })?
+ .schema(&cfg.catalog.default_schema) + .ok_or_else(|| { + DataFusionError::Plan(format!( + "failed to resolve schema: {}", + cfg.catalog.default_schema + )) + })?; + let names = schema.table_names(); + let query_tables = find_all_measurements(&statement, &names)?; + + let mut sp = ContextSchemaProvider { + state: &ctx.inner().state(), + tables: HashMap::with_capacity(query_tables.len()), + }; + + for table_name in &query_tables { + if let Entry::Vacant(v) = sp.tables.entry(table_name.to_string()) { + let mut ctx = ctx.child_ctx("get table schema"); + ctx.set_metadata("table", table_name.to_owned()); + + if let Some(table) = schema.table(table_name).await { + let schema = Schema::try_from(table.schema()) + .map_err(|err| { + DataFusionError::Internal(format!("unable to convert DataFusion schema for measurement {table_name} to IOx schema: {err}")) + })?; + v.insert((provider_as_source(table), schema)); + } + } + } + + let planner = InfluxQLToLogicalPlan::new(&sp, &ctx); + let logical_plan = planner.statement_to_plan(statement)?; + debug!(plan=%logical_plan.display_graphviz(), "logical plan"); + Ok(logical_plan) + } + + fn query_to_statement(&self, query: &str) -> Result { + let mut statements = + parse_statements(query).map_err(|e| DataFusionError::Plan(e.to_string()))?; + + if statements.len() != 1 { + return Err(DataFusionError::NotImplemented( + "The context currently only supports a single InfluxQL statement".to_string(), + )); + } + + Ok(statements.pop().unwrap()) + } +} + +fn find_all_measurements(stmt: &Statement, tables: &[String]) -> Result> { + struct Matcher<'a>(&'a mut HashSet, &'a [String]); + impl<'a> Visitor for Matcher<'a> { + type Error = DataFusionError; + + fn post_visit_measurement_name( + self, + mn: &MeasurementName, + ) -> std::result::Result { + match mn { + MeasurementName::Name(name) => { + let name = name.deref(); + if self.1.contains(name) { + self.0.insert(name.to_string()); + } + } + MeasurementName::Regex(re) => { + let re = 
parse_regex(re)?; + + self.1 + .iter() + .filter(|table| re.is_match(table)) + .for_each(|table| { + self.0.insert(table.into()); + }); + } + } + + Ok(self) + } + + fn post_visit_show_measurements_statement( + self, + sm: &ShowMeasurementsStatement, + ) -> Result { + if sm.with_measurement.is_none() { + self.0.extend(self.1.iter().cloned()); + } + + Ok(self) + } + + fn post_visit_show_field_keys_statement( + self, + sfk: &ShowFieldKeysStatement, + ) -> Result { + if sfk.from.is_none() { + self.0.extend(self.1.iter().cloned()); + } + + Ok(self) + } + + fn post_visit_show_tag_values_statement( + self, + stv: &ShowTagValuesStatement, + ) -> Result { + if stv.from.is_none() { + self.0.extend(self.1.iter().cloned()); + } + + Ok(self) + } + + fn post_visit_show_tag_keys_statement( + self, + stk: &ShowTagKeysStatement, + ) -> std::result::Result { + if stk.from.is_none() { + self.0.extend(self.1.iter().cloned()); + } + + Ok(self) + } + } + + let mut m = HashSet::new(); + let vis = Matcher(&mut m, tables); + stmt.accept(vis)?; + + Ok(m) +} + +#[cfg(test)] +mod test { + use super::*; + use itertools::Itertools; + use test_helpers::assert_error; + + #[test] + fn test_query_to_statement() { + let p = InfluxQLQueryPlanner::new(); + + // succeeds for a single statement + let _ = p.query_to_statement("SELECT foo FROM bar").unwrap(); + + // Fallible + + assert_error!( + p.query_to_statement("SELECT foo FROM bar; SELECT bar FROM foo"), + DataFusionError::NotImplemented(ref s) if s == "The context currently only supports a single InfluxQL statement" + ); + } + + #[test] + fn test_find_all_measurements() { + fn find(q: &str) -> Vec { + let p = InfluxQLQueryPlanner::new(); + let s = p.query_to_statement(q).unwrap(); + let tables = vec!["foo".into(), "bar".into(), "foobar".into()]; + let res = find_all_measurements(&s, &tables).unwrap(); + res.into_iter().sorted().collect() + } + + assert_eq!(find("SELECT * FROM foo"), vec!["foo"]); + assert_eq!(find("SELECT * FROM foo, foo"), 
vec!["foo"]); + assert_eq!(find("SELECT * FROM foo, bar"), vec!["bar", "foo"]); + assert_eq!(find("SELECT * FROM foo, none"), vec!["foo"]); + assert_eq!(find("SELECT * FROM /^foo/"), vec!["foo", "foobar"]); + assert_eq!(find("SELECT * FROM foo, /^bar/"), vec!["bar", "foo"]); + assert_eq!(find("SELECT * FROM //"), vec!["bar", "foo", "foobar"]); + + // Find all measurements in subqueries + assert_eq!( + find("SELECT * FROM foo, (SELECT * FROM bar)"), + vec!["bar", "foo"] + ); + assert_eq!( + find("SELECT * FROM foo, (SELECT * FROM /bar/)"), + vec!["bar", "foo", "foobar"] + ); + + // Find all measurements in `SHOW MEASUREMENTS` + assert_eq!(find("SHOW MEASUREMENTS"), vec!["bar", "foo", "foobar"]); + assert_eq!( + find("SHOW MEASUREMENTS WITH MEASUREMENT = foo"), + vec!["foo"] + ); + assert_eq!( + find("SHOW MEASUREMENTS WITH MEASUREMENT =~ /^foo/"), + vec!["foo", "foobar"] + ); + + // Find all measurements in `SHOW FIELD KEYS` + assert_eq!(find("SHOW FIELD KEYS"), vec!["bar", "foo", "foobar"]); + assert_eq!(find("SHOW FIELD KEYS FROM /^foo/"), vec!["foo", "foobar"]); + + // Find all measurements in `SHOW TAG VALUES` + assert_eq!( + find("SHOW TAG VALUES WITH KEY = \"k\""), + vec!["bar", "foo", "foobar"] + ); + assert_eq!( + find("SHOW TAG VALUES FROM /^foo/ WITH KEY = \"k\""), + vec!["foo", "foobar"] + ); + + // Find all measurements in `SHOW TAG KEYS` + assert_eq!(find("SHOW TAG KEYS"), vec!["bar", "foo", "foobar"]); + assert_eq!(find("SHOW TAG KEYS FROM /^foo/"), vec!["foo", "foobar"]); + + // Finds no measurements + assert!(find("SELECT * FROM none").is_empty()); + assert!(find("SELECT * FROM (SELECT * FROM none)").is_empty()); + assert!(find("SELECT * FROM /^l/").is_empty()); + assert!(find("SELECT * FROM (SELECT * FROM /^l/)").is_empty()); + } +} diff --git a/iox_query_influxql/src/lib.rs b/iox_query_influxql/src/lib.rs new file mode 100644 index 0000000..e236959 --- /dev/null +++ b/iox_query_influxql/src/lib.rs @@ -0,0 +1,28 @@ +//! 
Contains the IOx InfluxQL query planner
+#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)]
+#![warn(
+    missing_debug_implementations,
+    clippy::explicit_iter_loop,
+    clippy::use_self,
+    clippy::clone_on_ref_ptr,
+    // See https://github.com/influxdata/influxdb_iox/pull/1671
+    clippy::future_not_send,
+    clippy::todo,
+    clippy::dbg_macro,
+    unused_crate_dependencies
+)]
+
+use arrow::datatypes::DataType;
+
+// Workaround for "unused crate" lint false positives.
+use workspace_hack as _;
+
+mod aggregate;
+mod error;
+pub mod frontend;
+pub mod plan;
+mod window;
+
+/// A list of the numeric types supported by InfluxQL that can be used
+/// as input to user-defined functions.
+static NUMERICS: &[DataType] = &[DataType::Int64, DataType::UInt64, DataType::Float64];
diff --git a/iox_query_influxql/src/plan/expr_type_evaluator.rs b/iox_query_influxql/src/plan/expr_type_evaluator.rs
new file mode 100644
index 0000000..e2103ba
--- /dev/null
+++ b/iox_query_influxql/src/plan/expr_type_evaluator.rs
@@ -0,0 +1,808 @@
+use crate::error;
+use crate::plan::field::field_by_name;
+use crate::plan::field_mapper::map_type;
+use crate::plan::ir::DataSource;
+use crate::plan::var_ref::influx_type_to_var_ref_data_type;
+use crate::plan::SchemaProvider;
+use datafusion::common::Result;
+use influxdb_influxql_parser::expression::{
+    Binary, BinaryOperator, Call, Expr, VarRef, VarRefDataType,
+};
+use influxdb_influxql_parser::literal::Literal;
+use influxdb_influxql_parser::select::Dimension;
+use itertools::Itertools;
+
+/// Evaluate the type of the specified expression.
+///
+/// Derived from [Go implementation](https://github.com/influxdata/influxql/blob/1ba470371ec093d57a726b143fe6ccbacf1b452b/ast.go#L4796-L4797).
+pub(super) struct TypeEvaluator<'a> {
+    s: &'a dyn SchemaProvider,
+    from: &'a [DataSource],
+    /// Setting this to `true` will ensure scalar functions return errors for invalid data types.
+ /// The default is false, to ensure compatibility with InfluxQL OG. + call_type_is_strict: bool, +} + +impl<'a> TypeEvaluator<'a> { + /// Create a `TypeEvaluator` with behavior compatible with InfluxQL OG. + /// + /// This behavior includes limited evaluation of [`Call`] expressions, as described + /// by [`TypeEvaluator::eval_scalar`]. + pub(super) fn new(s: &'a dyn SchemaProvider, from: &'a [DataSource]) -> Self { + Self { + from, + s, + call_type_is_strict: false, + } + } + + /// Create a `TypeEvaluator` with strict behavior. + /// + /// This behavior includes strict evaluation of [`Call`] expressions, that are + /// not compatible with InfluxQL OG, but may be enabled in the future to improve + /// the user experience. + /// + /// # NOTE + /// + /// This behaviour is unused in production, but may be enabled to improve the + /// user experience of InfluxQL. + #[cfg(test)] + fn new_strict(s: &'a dyn SchemaProvider, from: &'a [DataSource]) -> Self { + Self { + from, + s, + call_type_is_strict: true, + } + } + + pub(super) fn eval_type(&self, expr: &Expr) -> Result> { + Ok(match expr { + Expr::VarRef(v) => self.eval_var_ref(v)?, + Expr::Call(v) => self.eval_call(v)?, + Expr::Binary(expr) => self.eval_binary_expr_type(expr)?, + Expr::Nested(expr) => self.eval_type(expr)?, + Expr::Literal(Literal::Float(_)) => Some(VarRefDataType::Float), + Expr::Literal(Literal::Unsigned(_)) => Some(VarRefDataType::Unsigned), + Expr::Literal(Literal::Integer(_)) => Some(VarRefDataType::Integer), + Expr::Literal(Literal::String(_)) => Some(VarRefDataType::String), + Expr::Literal(Literal::Boolean(_)) => Some(VarRefDataType::Boolean), + // Remaining patterns are not valid field types + Expr::BindParameter(_) + | Expr::Distinct(_) + | Expr::Wildcard(_) + | Expr::Literal(Literal::Duration(_)) + | Expr::Literal(Literal::Regex(_)) + | Expr::Literal(Literal::Timestamp(_)) => None, + }) + } + + fn eval_binary_expr_type(&self, expr: &Binary) -> Result> { + let (lhs, op, rhs) = ( + 
self.eval_type(&expr.lhs)?, + expr.op, + self.eval_type(&expr.rhs)?, + ); + + // Deviation from InfluxQL OG, which fails if one operand is unsigned and the other is + // an integer. This will let some additional queries succeed that would otherwise have + // failed. + // + // In this case, we will let DataFusion handle automatic coercion, rather than fail. + // + // See: https://github.com/influxdata/influxql/blob/802555d6b3a35cd464a6d8afa2a6511002cf3c2c/ast.go#L4729-L4730 + + match (lhs, rhs) { + (Some(dt), None) | (None, Some(dt)) => Ok(Some(dt)), + (None, None) => Ok(None), + (Some(lhs), Some(rhs)) => { + Ok(Some(binary_data_type(lhs, op, rhs).ok_or_else(|| { + error::map::query(format!( + "incompatible operands for operator {op}: {lhs} and {rhs}" + )) + })?)) + } + } + } + + /// Returns the type for the specified [`VarRef`]. + /// + /// This function assumes that the expression has already been reduced. + pub(super) fn eval_var_ref(&self, expr: &VarRef) -> Result> { + Ok(match expr.data_type { + Some(dt) + if matches!( + dt, + VarRefDataType::Integer + | VarRefDataType::Unsigned + | VarRefDataType::Float + | VarRefDataType::String + | VarRefDataType::Boolean + | VarRefDataType::Tag + ) => + { + Some(dt) + } + _ => { + let mut data_type: Option = None; + for tr in self.from { + match tr { + DataSource::Table(name) => match ( + data_type, + map_type(self.s, name.as_str(), expr.name.as_str()), + ) { + (Some(existing), Some(res)) => { + if res < existing { + data_type = Some(res) + } + } + (None, Some(res)) => data_type = Some(res), + _ => continue, + }, + DataSource::Subquery(select) => { + // find the field by name + if let Some(field) = field_by_name(&select.fields, expr.name.as_str()) { + match (data_type, influx_type_to_var_ref_data_type(field.data_type)) + { + (Some(existing), Some(res)) => { + if res < existing { + data_type = Some(res) + } + } + (None, Some(res)) => data_type = Some(res), + _ => {} + } + }; + + if data_type.is_none() { + if let 
Some(group_by) = &select.group_by { + if group_by.iter().any(|dim| { + matches!(dim, Dimension::VarRef(VarRef { name, ..}) if name.as_str() == expr.name.as_str()) + }) { + data_type = Some(VarRefDataType::Tag); + } + } + } + } + } + } + + data_type + } + }) + } + + /// Evaluate the datatype of the function identified by `name`. + /// + /// Derived from [Go implementation](https://github.com/influxdata/influxql/blob/1ba470371ec093d57a726b143fe6ccbacf1b452b/ast.go#L4693) + /// and [here](https://github.com/influxdata/influxdb/blob/37088e8f5330bec0f08a376b2cb945d02a296f4e/influxql/query/functions.go#L50). + fn eval_call(&self, call: &Call) -> Result> { + // Evaluate the data types of the arguments + let arg_types: Vec<_> = call + .args + .iter() + .map(|expr| self.eval_type(expr)) + .try_collect()?; + + Ok(match call.name.as_str() { + // See: https://github.com/influxdata/influxdb/blob/e484c4d87193a475466c0285c018d16f168139e6/query/functions.go#L54-L60 + "mean" => Some(VarRefDataType::Float), + "count" => Some(VarRefDataType::Integer), + // These functions return the same type as their first argument + "min" | "max" | "sum" | "first" | "last" | "distinct" => match arg_types.first() { + Some(v) => *v, + None => None, + }, + + // See: https://github.com/influxdata/influxdb/blob/e484c4d87193a475466c0285c018d16f168139e6/query/functions.go#L80 + "median" + | "integral" + | "stddev" + | "derivative" + | "non_negative_derivative" + | "moving_average" + | "exponential_moving_average" + | "double_exponential_moving_average" + | "triple_exponential_moving_average" + | "relative_strength_index" + | "triple_exponential_derivative" + | "kaufmans_efficiency_ratio" + | "kaufmans_adaptive_moving_average" + | "chande_momentum_oscillator" + | "holt_winters" + | "holt_winters_with_fit" => Some(VarRefDataType::Float), + "elapsed" => Some(VarRefDataType::Integer), + + name => self.eval_scalar(name, &arg_types)?, + }) + } + + /// Evaluate the data type of a scalar function + /// + /// See: 
+ /// + /// 💥InfluxQL OG has a bug that it does not evaluate call types correctly, and returns + /// the incorrect type by unconditionally using the first argument. It does not even call the + /// mapper to evaluate scalar functions. We must replicate the InfluxQL OG behaviour, + /// or queries will fail, that would ordinarily succeed. + /// + /// The bug may be traced through the OG source as follows. + /// + /// Prior to executing a `SELECT`, the following steps occur to validate all the field + /// expression types. + /// + /// 1. Calls `validateTypes` to ensure all field data types are valid: + /// + /// + /// 2. Uses a `MultiTypeMapper` to evaluate types, combining: + /// + /// * a `FunctionTypeMapper` for sum, min, max, etc + /// * a `MathTypeMapper` for scalar functions like log, abs, etc + /// + /// ⚠️NOTE: the order is important. `FunctionTypeMapper` is called first. + /// + /// See: + /// + /// 3. Call `EvalType` for each field: + /// + /// See: + /// + /// 4. For fields that have call expressions, the `evalCallExprType` function is ultimately called + /// + /// See: + /// + /// 5. Because the `TypeMapper` is a `CallTypeMapper`, `evalCallExprType` eventually calls `CallType`: + /// + /// See: + /// + /// 6. The `TypeMapper` is a `multiTypeMapper` and thus calls `CallType` for each instance. The first + /// inner call that returns no error and the `typ` is not `Unknown` will be returned to the caller + /// + /// See: + /// + /// 7. Recall, the first `TypeMapper` is `FunctionTypeMapper`, so it's `CallType` is + /// called first. 
+ /// + /// 🪳Here is the bug, which is that `FunctionTypeMapper::CallType` always returns + /// the type of the first argument: + /// + /// See: + fn eval_scalar( + &self, + name: &str, + arg_types: &[Option], + ) -> Result> { + if self.call_type_is_strict { + self.eval_scalar_strict(name, arg_types) + } else { + self.eval_scalar_compatible(arg_types) + } + } + + fn eval_scalar_compatible( + &self, + arg_types: &[Option], + ) -> Result> { + Ok(arg_types.first().and_then(|v| *v)) + } + + fn eval_scalar_strict( + &self, + name: &str, + arg_types: &[Option], + ) -> Result> { + match name { + // These functions require a single numeric as input and return a float + name @ ("sin" | "cos" | "tan" | "atan" | "exp" | "log" | "ln" | "log2" | "log10" + | "sqrt") => { + match arg_types + .first() + .ok_or_else(|| error::map::query(format!("{name} expects 1 argument")))? + { + Some( + VarRefDataType::Float | VarRefDataType::Integer | VarRefDataType::Unsigned, + ) + | None => Ok(Some(VarRefDataType::Float)), + Some(arg0) => error::query(format!( + "invalid argument type for {name}: expected a number, got {arg0}" + )), + } + } + + // These functions require a single float as input and return a float + name @ ("asin" | "acos") => { + match arg_types + .first() + .ok_or_else(|| error::map::query(format!("{name} expects 1 argument")))? 
+ { + Some(VarRefDataType::Float) | None => Ok(Some(VarRefDataType::Float)), + Some(arg0) if self.call_type_is_strict => error::query(format!( + "invalid argument type for {name}: expected a float, got {arg0}" + )), + _ => Ok(None), + } + } + + // These functions require two numeric arguments and return a float + name @ ("atan2" | "pow") => { + let (Some(arg0), Some(arg1)) = (arg_types.first(), arg_types.get(1)) else { + return error::query(format!("{name} expects 2 arguments")); + }; + + match (arg0, arg1) { + (Some( + VarRefDataType::Float + | VarRefDataType::Integer + | VarRefDataType::Unsigned + ) | None, Some( + VarRefDataType::Float + | VarRefDataType::Integer + | VarRefDataType::Unsigned + ) | None) => Ok(Some(VarRefDataType::Float)), + (arg0, arg1) if self.call_type_is_strict => error::query(format!( + "invalid argument types for {name}: expected a number for both arguments, got ({arg0:?}, {arg1:?})" + )), + _ => Ok(None), + } + } + + // These functions return the same data type as their input + name @ ("abs" | "floor" | "ceil" | "round") => { + match arg_types + .first() + .cloned() + .ok_or_else(|| error::map::query(format!("{name} expects 1 argument")))? + { + // Return the same data type as the input + dt @ Some( + VarRefDataType::Float | VarRefDataType::Integer | VarRefDataType::Unsigned, + ) => Ok(dt), + // If the input is unknown, default to float + None => Ok(Some(VarRefDataType::Float)), + Some(arg0) if self.call_type_is_strict => error::query(format!( + "invalid argument type for {name}: expected a number, got {arg0}" + )), + _ => Ok(None), + } + } + _ => Ok(None), + } + } +} + +/// Determine the data type of the binary expression using the left and right operands and the operator +/// +/// This logic is derived from [InfluxQL OG][og]. 
+/// +/// [og]: https://github.com/influxdata/influxql/blob/802555d6b3a35cd464a6d8afa2a6511002cf3c2c/ast.go#L4192 +fn binary_data_type( + lhs: VarRefDataType, + op: BinaryOperator, + rhs: VarRefDataType, +) -> Option { + use BinaryOperator::*; + use VarRefDataType::{Boolean, Float, Integer, Unsigned}; + + match (lhs, op, rhs) { + // Boolean only supports bitwise operators. + // + // See: + // * https://github.com/influxdata/influxql/blob/802555d6b3a35cd464a6d8afa2a6511002cf3c2c/ast.go#L4210 + (Boolean, BitwiseAnd | BitwiseOr | BitwiseXor, Boolean) => Some(Boolean), + + // A float for either operand is a float result, but only + // support the +, -, * / and % operators. + // + // See: + // * https://github.com/influxdata/influxql/blob/802555d6b3a35cd464a6d8afa2a6511002cf3c2c/ast.go#L4228 + // * https://github.com/influxdata/influxql/blob/802555d6b3a35cd464a6d8afa2a6511002cf3c2c/ast.go#L4285 + // * https://github.com/influxdata/influxql/blob/802555d6b3a35cd464a6d8afa2a6511002cf3c2c/ast.go#L4411 + (Float, Add | Sub | Mul | Div | Mod, Float | Integer | Unsigned) + | (Integer | Unsigned, Add | Sub | Mul | Div | Mod, Float) => Some(Float), + + // Integers using the division operator are always float + // + // See: + // * https://github.com/influxdata/influxql/blob/802555d6b3a35cd464a6d8afa2a6511002cf3c2c/ast.go#L4335-L4340 + // * https://github.com/influxdata/influxdb/blob/3372d3b878ebcba708dc9edfce7ea83cc8152393/query/cursor.go#L178 + (Integer, Div, Integer) => Some(Float), + + // Integer and unsigned types support all operands and + // the result is the same type if both operands are the same. 
+ // + // See: + // * https://github.com/influxdata/influxql/blob/802555d6b3a35cd464a6d8afa2a6511002cf3c2c/ast.go#L4314 + // * https://github.com/influxdata/influxql/blob/802555d6b3a35cd464a6d8afa2a6511002cf3c2c/ast.go#L4489 + (Integer, _, Integer) | (Unsigned, _, Unsigned) => Some(lhs), + + // If either side is unsigned, and the other is integer, + // the result is unsigned for all operators. + // + // See: + // * https://github.com/influxdata/influxql/blob/802555d6b3a35cd464a6d8afa2a6511002cf3c2c/ast.go#L4358 + // * https://github.com/influxdata/influxql/blob/802555d6b3a35cd464a6d8afa2a6511002cf3c2c/ast.go#L4440 + (Unsigned, _, Integer) | (Integer, _, Unsigned) => Some(Unsigned), + + // String or any other combination of operator and operands are invalid + // + // See: + // * https://github.com/influxdata/influxql/blob/802555d6b3a35cd464a6d8afa2a6511002cf3c2c/ast.go#L4562 + _ => None, + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::plan::expr_type_evaluator::binary_data_type; + use crate::plan::ir::DataSource; + use crate::plan::test_utils::MockSchemaProvider; + use assert_matches::assert_matches; + use datafusion::common::DataFusionError; + use influxdb_influxql_parser::expression::VarRefDataType; + use influxdb_influxql_parser::select::Field; + use itertools::iproduct; + + #[test] + fn test_binary_data_type() { + use influxdb_influxql_parser::expression::BinaryOperator::*; + use VarRefDataType::{Boolean, Float, Integer, String, Tag, Timestamp, Unsigned}; + + // Boolean ok + for op in [BitwiseAnd, BitwiseOr, BitwiseXor] { + assert_matches!( + binary_data_type(Boolean, op, Boolean), + Some(VarRefDataType::Boolean) + ); + } + + // Boolean !ok + for op in [Add, Sub, Div, Mul, Mod] { + assert_matches!(binary_data_type(Boolean, op, Boolean), None); + } + + // Float ok + for (op, operand) in iproduct!([Add, Sub, Div, Mul, Mod], [Float, Integer, Unsigned]) { + assert_matches!(binary_data_type(Float, op, operand), Some(Float)); + 
assert_matches!(binary_data_type(operand, op, Float), Some(Float)); + } + + // Float !ok + for (op, operand) in iproduct!( + [BitwiseAnd, BitwiseOr, BitwiseXor], + [Float, Integer, Unsigned] + ) { + assert_matches!(binary_data_type(Float, op, operand), None); + assert_matches!(binary_data_type(operand, op, Float), None); + } + + // division and integers are special + assert_matches!(binary_data_type(Integer, Div, Integer), Some(Float)); + assert_matches!(binary_data_type(Unsigned, Div, Unsigned), Some(Unsigned)); + + // Integer op Integer | Unsigned op Unsigned + for op in [Add, Sub, Mul, Mod, BitwiseAnd, BitwiseOr, BitwiseXor] { + assert_matches!(binary_data_type(Integer, op, Integer), Some(Integer)); + assert_matches!(binary_data_type(Unsigned, op, Unsigned), Some(Unsigned)); + } + + // Unsigned op Integer | Integer op Unsigned + for op in [Add, Sub, Div, Mul, Mod, BitwiseAnd, BitwiseOr, BitwiseXor] { + assert_matches!(binary_data_type(Integer, op, Unsigned), Some(Unsigned)); + assert_matches!(binary_data_type(Unsigned, op, Integer), Some(Unsigned)); + } + + // Fallible cases + + assert_matches!(binary_data_type(Tag, Add, Tag), None); + assert_matches!(binary_data_type(String, Add, String), None); + assert_matches!(binary_data_type(Timestamp, Add, Timestamp), None); + } + + #[test] + fn test_evaluate_type() { + let namespace = MockSchemaProvider::default(); + + fn evaluate_type( + s: &dyn SchemaProvider, + expr: &str, + from: &[&str], + ) -> Result> { + let from = from + .iter() + .map(ToString::to_string) + .map(DataSource::Table) + .collect::>(); + let Field { expr, .. 
} = expr.parse().unwrap(); + TypeEvaluator::new(s, &from).eval_type(&expr) + } + + let res = evaluate_type(&namespace, "shared_field0", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Float); + + let res = evaluate_type(&namespace, "shared_tag0", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Tag); + + // Unknown + let res = evaluate_type(&namespace, "not_exists", &["temp_01"]).unwrap(); + assert!(res.is_none()); + + let res = evaluate_type(&namespace, "shared_field0", &["temp_02"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Integer); + + let res = evaluate_type(&namespace, "shared_field0", &["temp_02"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Integer); + + // Same field across multiple measurements resolves to the highest precedence (float) + let res = evaluate_type(&namespace, "shared_field0", &["temp_01", "temp_02"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Float); + + // Explicit cast of integer field to float + let res = evaluate_type(&namespace, "SUM(field_i64::float)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Float); + + // + // Binary expressions + // + + let res = evaluate_type(&namespace, "field_f64 + field_i64", &["all_types"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Float); + + let res = evaluate_type(&namespace, "field_bool | field_bool", &["all_types"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Boolean); + + // Fallible + + // Verify incompatible operators and operator error + let res = evaluate_type(&namespace, "field_f64 & field_i64", &["all_types"]); + assert_matches!(res, Err(DataFusionError::Plan(ref s)) if s == "incompatible operands for operator &: float and integer"); + + // data types for functions + let res = evaluate_type(&namespace, "SUM(field_f64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, 
VarRefDataType::Float); + + let res = evaluate_type(&namespace, "SUM(field_i64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Integer); + + let res = evaluate_type(&namespace, "SUM(field_u64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Unsigned); + + let res = evaluate_type(&namespace, "MIN(field_f64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Float); + + let res = evaluate_type(&namespace, "MAX(field_i64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Integer); + + let res = evaluate_type(&namespace, "FIRST(field_str)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::String); + + let res = evaluate_type(&namespace, "LAST(field_str)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::String); + + let res = evaluate_type(&namespace, "DISTINCT(field_str)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::String); + + let res = evaluate_type(&namespace, "MEAN(field_i64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Float); + + let res = evaluate_type(&namespace, "MEAN(field_u64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Float); + + let res = evaluate_type(&namespace, "COUNT(field_f64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Integer); + + let res = evaluate_type(&namespace, "COUNT(field_i64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Integer); + + let res = evaluate_type(&namespace, "COUNT(field_u64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Integer); + + let res = evaluate_type(&namespace, "COUNT(field_str)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Integer); + + // Float functions + for call in [ + "median(field_i64)", + 
"integral(field_i64)", + "stddev(field_i64)", + "derivative(field_i64)", + "non_negative_derivative(field_i64)", + "moving_average(field_i64, 2)", + "exponential_moving_average(field_i64, 2)", + "double_exponential_moving_average(field_i64, 2)", + "triple_exponential_moving_average(field_i64, 2)", + "relative_strength_index(field_i64, 2)", + "triple_exponential_derivative(field_i64, 2)", + "kaufmans_efficiency_ratio(field_i64, 2)", + "kaufmans_adaptive_moving_average(field_i64, 2)", + "chande_momentum_oscillator(field_i64, 2)", + ] { + let res = evaluate_type(&namespace, call, &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Float); + } + + // holt_winters + let res = evaluate_type( + &namespace, + "holt_winters(mean(field_i64), 2, 3)", + &["temp_01"], + ) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Float); + + // holt_winters_with_fit + let res = evaluate_type( + &namespace, + "holt_winters_with_fit(mean(field_i64), 2, 3)", + &["temp_01"], + ) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Float); + + // Integer functions + let res = evaluate_type(&namespace, "elapsed(field_i64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Integer); + + // scalar functions + + // These require a single numeric input and return a float + let res = evaluate_type(&namespace, "sin(field_f64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Float); + + // These require a single float as input and return a float + let res = evaluate_type(&namespace, "asin(field_f64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Float); + + // These require two numeric arguments as input and return a float + let res = evaluate_type(&namespace, "atan2(field_f64, 3)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Float); + + // These require a numeric argument as input and return the same type + let res = 
evaluate_type(&namespace, "abs(field_f64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Float); + let res = evaluate_type(&namespace, "abs(field_i64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Integer); + let res = evaluate_type(&namespace, "abs(field_u64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Unsigned); + } + + /// Validate InfluxQL OG compatible behavior for scalar functions + #[test] + fn test_evaluate_type_compat() { + let namespace = MockSchemaProvider::default(); + + fn evaluate_type( + s: &dyn SchemaProvider, + expr: &str, + from: &[&str], + ) -> Result> { + let from = from + .iter() + .map(ToString::to_string) + .map(DataSource::Table) + .collect::>(); + let Field { expr, .. } = expr.parse().unwrap(); + TypeEvaluator::new(s, &from).eval_type(&expr) + } + + let res = evaluate_type(&namespace, "sin(field_i64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Integer); + let res = evaluate_type(&namespace, "sin(field_str)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::String); + + let res = evaluate_type(&namespace, "asin(field_i64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Integer); + + // invalid number of arguments, still returns data type of first arg + let res = evaluate_type(&namespace, "atan2(field_f64)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Float); + + let res = evaluate_type(&namespace, "atan2(field_str, 3)", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::String); + let res = evaluate_type(&namespace, "atan2(field_i64, 'str')", &["temp_01"]) + .unwrap() + .unwrap(); + assert_matches!(res, VarRefDataType::Integer); + } + + /// Validates `TypeEvaluator` when in strict mode. 
+    #[test]
+    fn test_evaluate_type_strict() {
+        let namespace = MockSchemaProvider::default();
+
+        fn evaluate_type(
+            s: &dyn SchemaProvider,
+            expr: &str,
+            from: &[&str],
+        ) -> Result<Option<VarRefDataType>> {
+            let from = from
+                .iter()
+                .map(ToString::to_string)
+                .map(DataSource::Table)
+                .collect::<Vec<_>>();
+            let Field { expr, .. } = expr.parse().unwrap();
+            TypeEvaluator::new_strict(s, &from).eval_type(&expr)
+        }
+
+        // In strict mode, these scalar functions should return an error when the arguments are an
+        // invalid data type.
+
+        evaluate_type(&namespace, "sin(field_str)", &["temp_01"]).unwrap_err();
+        evaluate_type(&namespace, "asin(field_i64)", &["temp_01"]).unwrap_err();
+        evaluate_type(&namespace, "atan2(field_f64)", &["temp_01"]).unwrap_err();
+        evaluate_type(&namespace, "atan2(field_str, 3)", &["temp_01"]).unwrap_err();
+        evaluate_type(&namespace, "atan2(field_i64, 'str')", &["temp_01"]).unwrap_err();
+        evaluate_type(&namespace, "abs(field_str)", &["temp_01"]).unwrap_err();
+    }
+}
diff --git a/iox_query_influxql/src/plan/field.rs b/iox_query_influxql/src/plan/field.rs
new file mode 100644
index 0000000..cf989b1
--- /dev/null
+++ b/iox_query_influxql/src/plan/field.rs
@@ -0,0 +1,191 @@
+use crate::plan::ir::Field;
+use influxdb_influxql_parser::expression::{Call, Expr, VarRef};
+use influxdb_influxql_parser::visit::{Recursion, Visitable, Visitor};
+use std::ops::Deref;
+
+/// Returns the name of the field.
+///
+/// Prefers the alias if set, otherwise derives the name
+/// from [Expr::VarRef] or [Expr::Call]. Finally, if neither
+/// are available, falls back to an empty string.
+///
+/// Derived from [Go implementation](https://github.com/influxdata/influxql/blob/1ba470371ec093d57a726b143fe6ccbacf1b452b/ast.go#L3326-L3328)
+pub(crate) fn field_name(f: &influxdb_influxql_parser::select::Field) -> String {
+    if let Some(alias) = &f.alias {
+        return alias.deref().to_string();
+    }
+
+    let mut expr = &f.expr;
+    loop {
+        expr = match expr {
+            Expr::Call(Call { name, ..
}) => return name.clone(), + Expr::Nested(nested) => nested, + Expr::Binary { .. } => return binary_expr_name(&f.expr), + Expr::Distinct(_) => return "distinct".to_string(), + Expr::VarRef(VarRef { name, .. }) => return name.deref().into(), + Expr::Wildcard(_) | Expr::BindParameter(_) | Expr::Literal(_) => return "".to_string(), + }; + } +} + +/// Returns the expression that matches the field name. +/// +/// If the name matches one of the arguments to +/// "top" or "bottom", the variable reference inside of the function is returned. +/// +/// Derive from [this implementation](https://github.com/influxdata/influxql/blob/1ba470371ec093d57a726b143fe6ccbacf1b452b/ast.go#L1725) +/// +/// **NOTE** +/// +/// This implementation duplicates the behavior of the original implementation, including skipping the +/// first argument. It is likely the original intended to skip the _last_ argument, which is the number +/// of rows. +pub(super) fn field_by_name<'a>(fields: &'a [Field], name: &str) -> Option<&'a Field> { + fields.iter().find(|f| f.name == name || match &f.expr { + Expr::Call(Call{ name: func_name, args }) if (func_name == "top" + || func_name == "bottom") + && args.len() > 2 => + args[1..].iter().any(|f| matches!(f, Expr::VarRef(VarRef{ name: field_name, .. }) if field_name.as_str() == name)), + _ => false, + }) +} + +struct BinaryExprNameVisitor<'a>(&'a mut Vec); + +impl<'a> Visitor for BinaryExprNameVisitor<'a> { + type Error = (); + + fn pre_visit_var_ref(self, n: &VarRef) -> Result, Self::Error> { + self.0.push(n.name.to_string()); + Ok(Recursion::Continue(self)) + } + + fn pre_visit_call(self, n: &Call) -> Result, Self::Error> { + self.0.push(n.name.clone()); + Ok(Recursion::Stop(self)) + } +} + +/// Returns the name of a binary expression by concatenating +/// the names of any [Expr::VarRef] and [Expr::Call] with underscores. 
+/// +/// Derived from [Go implementation](https://github.com/influxdata/influxql/blob/1ba470371ec093d57a726b143fe6ccbacf1b452b/ast.go#L3729-L3731) +fn binary_expr_name(expr: &Expr) -> String { + let mut names = Vec::new(); + let vis = BinaryExprNameVisitor(&mut names); + expr.accept(vis).unwrap(); // It is not expected to fail + names.join("_") +} + +#[cfg(test)] +mod test { + use crate::plan::field::{field_by_name, field_name}; + use crate::plan::ir; + use assert_matches::assert_matches; + use influxdb_influxql_parser::select::Field; + + #[test] + fn test_field_name() { + let f: Field = "usage".parse().unwrap(); + assert_eq!(field_name(&f), "usage"); + + let f: Field = "usage as u2".parse().unwrap(); + assert_eq!(field_name(&f), "u2"); + + let f: Field = "(usage)".parse().unwrap(); + assert_eq!(field_name(&f), "usage"); + + let f: Field = "COUNT(usage)".parse().unwrap(); + assert_eq!(field_name(&f), "count"); + + let f: Field = "COUNT(usage) + SUM(usage_idle)".parse().unwrap(); + assert_eq!(field_name(&f), "count_sum"); + + let f: Field = "1+2".parse().unwrap(); + assert_eq!(field_name(&f), ""); + + let f: Field = "1 + usage".parse().unwrap(); + assert_eq!(field_name(&f), "usage"); + + let f: Field = "/reg/".parse().unwrap(); + assert_eq!(field_name(&f), ""); + + let f: Field = "DISTINCT usage".parse().unwrap(); + assert_eq!(field_name(&f), "distinct"); + + let f: Field = "-usage".parse().unwrap(); + assert_eq!(field_name(&f), "usage"); + + // Doesn't quote keyword + let f: Field = "\"user\"".parse().unwrap(); + assert_eq!(field_name(&f), "user"); + } + + #[test] + fn test_field_by_name() { + fn parse_fields(exprs: Vec<&str>) -> Vec { + exprs + .iter() + .map(|s| { + let f: Field = s.parse().unwrap(); + let name = field_name(&f); + let data_type = None; + ir::Field { + expr: f.expr, + name, + data_type, + } + }) + .collect() + } + let stmt = parse_fields(vec!["usage", "idle"]); + assert_eq!( + format!("{}", field_by_name(&stmt, "usage").unwrap()), + "usage AS 
usage" + ); + + let stmt = parse_fields(vec!["usage as foo", "usage"]); + assert_eq!( + format!("{}", field_by_name(&stmt, "foo").unwrap()), + "usage AS foo" + ); + + let stmt = parse_fields(vec!["top(idle, usage, 5)", "usage"]); + assert_eq!( + format!("{}", field_by_name(&stmt, "usage").unwrap()), + "top(idle, usage, 5) AS top" + ); + + let stmt = parse_fields(vec!["bottom(idle, usage, 5)", "usage"]); + assert_eq!( + format!("{}", field_by_name(&stmt, "usage").unwrap()), + "bottom(idle, usage, 5) AS bottom" + ); + + // TOP is in uppercase, to ensure we can expect the function name to be + // uniformly lowercase. + let stmt = parse_fields(vec!["TOP(idle, usage, 5) as foo", "usage"]); + assert_eq!( + format!("{}", field_by_name(&stmt, "usage").unwrap()), + "top(idle, usage, 5) AS foo" + ); + assert_eq!( + format!("{}", field_by_name(&stmt, "foo").unwrap()), + "top(idle, usage, 5) AS foo" + ); + + // Not exists + + let stmt = parse_fields(vec!["usage", "idle"]); + assert_matches!(field_by_name(&stmt, "bar"), None); + + // Does not match name by first argument to top or bottom, per + // bug in original implementation. 
+ let stmt = parse_fields(vec!["top(foo, usage, 5)", "idle"]); + assert_matches!(field_by_name(&stmt, "foo"), None); + assert_eq!( + format!("{}", field_by_name(&stmt, "idle").unwrap()), + "idle AS idle" + ); + } +} diff --git a/iox_query_influxql/src/plan/field_mapper.rs b/iox_query_influxql/src/plan/field_mapper.rs new file mode 100644 index 0000000..24d09d9 --- /dev/null +++ b/iox_query_influxql/src/plan/field_mapper.rs @@ -0,0 +1,92 @@ +use crate::plan::ir::TagSet; +use crate::plan::var_ref::{field_type_to_var_ref_data_type, influx_type_to_var_ref_data_type}; +use crate::plan::SchemaProvider; +use influxdb_influxql_parser::expression::VarRefDataType; +use schema::InfluxColumnType; +use std::collections::HashMap; + +pub(crate) type FieldTypeMap = HashMap; + +pub(crate) fn field_and_dimensions( + s: &dyn SchemaProvider, + name: &str, +) -> Option<(FieldTypeMap, TagSet)> { + s.table_schema(name).map(|iox| { + let mut field_set = FieldTypeMap::new(); + let mut tag_set = TagSet::new(); + + for col in iox.iter() { + match col { + (InfluxColumnType::Field(ft), f) => { + field_set.insert(f.name().to_owned(), field_type_to_var_ref_data_type(ft)); + } + (InfluxColumnType::Tag, f) => { + tag_set.insert(f.name().to_owned()); + } + (InfluxColumnType::Timestamp, _) => {} + } + } + (field_set, tag_set) + }) +} + +pub(crate) fn map_type( + s: &dyn SchemaProvider, + measurement_name: &str, + field: &str, +) -> Option { + s.table_schema(measurement_name).and_then(|iox| { + iox.field_by_name(field) + .and_then(|(dt, _)| influx_type_to_var_ref_data_type(Some(dt))) + }) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::plan::test_utils::MockSchemaProvider; + use assert_matches::assert_matches; + + #[test] + fn test_schema_field_mapper() { + let namespace = MockSchemaProvider::default(); + + // Measurement exists + let (field_set, tag_set) = field_and_dimensions(&namespace, "cpu").unwrap(); + assert_eq!( + field_set, + FieldTypeMap::from([ + ("usage_user".to_string(), 
VarRefDataType::Float), + ("usage_system".to_string(), VarRefDataType::Float), + ("usage_idle".to_string(), VarRefDataType::Float), + ]) + ); + assert_eq!( + tag_set, + TagSet::from(["cpu".to_string(), "host".to_string(), "region".to_string()]) + ); + + // Measurement does not exist + assert!(field_and_dimensions(&namespace, "cpu2").is_none()); + + // `map_type` API calls + + // Returns expected type + assert_matches!( + map_type(&namespace, "cpu", "usage_user"), + Some(VarRefDataType::Float) + ); + assert_matches!( + map_type(&namespace, "cpu", "host"), + Some(VarRefDataType::Tag) + ); + assert_matches!( + map_type(&namespace, "cpu", "time"), + Some(VarRefDataType::Timestamp) + ); + // Returns None for nonexistent field + assert!(map_type(&namespace, "cpu", "nonexistent").is_none()); + // Returns None for nonexistent measurement + assert!(map_type(&namespace, "nonexistent", "usage").is_none()); + } +} diff --git a/iox_query_influxql/src/plan/ir.rs b/iox_query_influxql/src/plan/ir.rs new file mode 100644 index 0000000..7ee811d --- /dev/null +++ b/iox_query_influxql/src/plan/ir.rs @@ -0,0 +1,235 @@ +//! Defines data structures which represent an InfluxQL +//! statement after it has been processed + +use crate::error; +use crate::plan::rewriter::ProjectionType; +use datafusion::common::Result; +use influxdb_influxql_parser::common::{ + LimitClause, MeasurementName, OffsetClause, OrderByClause, QualifiedMeasurementName, + WhereClause, +}; +use influxdb_influxql_parser::expression::{ConditionalExpression, Expr}; +use influxdb_influxql_parser::select::{ + FieldList, FillClause, FromMeasurementClause, GroupByClause, MeasurementSelection, + SelectStatement, TimeZoneClause, +}; +use influxdb_influxql_parser::time_range::TimeRange; +use schema::{InfluxColumnType, Schema}; +use std::collections::HashSet; +use std::fmt::{Display, Formatter}; + +use super::SchemaProvider; + +/// A set of tag keys. 
+pub(super) type TagSet = HashSet<String>;
+
+/// Represents a validated and normalized top-level [`SelectStatement`].
+#[derive(Debug, Default, Clone)]
+pub(super) struct SelectQuery {
+    pub(super) select: Select,
+}
+
+#[derive(Debug, Default, Clone)]
+pub(super) struct Select {
+    /// The projection type of the selection.
+    pub(super) projection_type: ProjectionType,
+
+    /// The interval derived from the arguments to the `TIME` function
+    /// when a `GROUP BY` clause is declared with `TIME`.
+    pub(super) interval: Option<Interval>,
+
+    /// The number of additional intervals that must be read
+    /// for queries that group by time and use window functions such as
+    /// `DIFFERENCE` or `DERIVATIVE`. This ensures data for the first
+    /// window is available.
+    ///
+    /// See: 
+    pub(super) extra_intervals: usize,
+
+    /// Projection clause of the selection.
+    pub(super) fields: Vec<Field>,
+
+    /// A list of data sources for the selection.
+    pub(super) from: Vec<DataSource>,
+
+    /// A conditional expression to filter the selection, excluding any predicates for the `time`
+    /// column.
+    pub(super) condition: Option<ConditionalExpression>,
+
+    /// The time range derived from the `WHERE` clause of the `SELECT` statement.
+    pub(super) time_range: TimeRange,
+
+    /// The GROUP BY clause of the selection.
+    pub(super) group_by: Option<GroupByClause>,
+
+    /// The set of possible tags for the selection, by combining
+    /// the tag sets of all inputs via the `FROM` clause.
+    pub(super) tag_set: TagSet,
+
+    /// The [fill] clause specifies the fill behaviour for the selection. If the value is [`None`],
+    /// it is the same behavior as `fill(null)`.
+    ///
+    /// [fill]: https://docs.influxdata.com/influxdb/v1.8/query_language/explore-data/#group-by-time-intervals-and-fill
+    pub(super) fill: Option<FillClause>,
+
+    /// Configures the ordering of the selection by time.
+    pub(super) order_by: Option<OrderByClause>,
+
+    /// A value to restrict the number of rows returned.
+    pub(super) limit: Option<LimitClause>,
+
+    /// A value to specify an offset to start retrieving rows.
+    pub(super) offset: Option<OffsetClause>,