diff --git a/docker/sql_setup.sh b/docker/sql_setup.sh index 422dcbda9..0315ac805 100755 --- a/docker/sql_setup.sh +++ b/docker/sql_setup.sh @@ -96,4 +96,5 @@ psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" <<-EOSQL CREATE ROLE ssl_user LOGIN; CREATE EXTENSION hstore; CREATE EXTENSION citext; + CREATE EXTENSION ltree; EOSQL diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml index 2010e88ad..a4716907b 100644 --- a/postgres-protocol/Cargo.toml +++ b/postgres-protocol/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres-protocol" -version = "0.6.3" +version = "0.6.4" authors = ["Steven Fackler "] edition = "2018" description = "Low level Postgres protocol APIs" diff --git a/postgres-protocol/src/types/mod.rs b/postgres-protocol/src/types/mod.rs index a595f5a30..05f515f76 100644 --- a/postgres-protocol/src/types/mod.rs +++ b/postgres-protocol/src/types/mod.rs @@ -1059,3 +1059,60 @@ impl Inet { self.netmask } } + +/// Serializes a Postgres ltree string +#[inline] +pub fn ltree_to_sql(v: &str, buf: &mut BytesMut) { + // A version number is prepended to an ltree string per spec + buf.put_u8(1); + // Append the rest of the query + buf.put_slice(v.as_bytes()); +} + +/// Deserialize a Postgres ltree string +#[inline] +pub fn ltree_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + match buf { + // Remove the version number from the front of the ltree per spec + [1u8, rest @ ..] => Ok(str::from_utf8(rest)?), + _ => Err("ltree version 1 only supported".into()), + } +} + +/// Serializes a Postgres lquery string +#[inline] +pub fn lquery_to_sql(v: &str, buf: &mut BytesMut) { + // A version number is prepended to an lquery string per spec + buf.put_u8(1); + // Append the rest of the query + buf.put_slice(v.as_bytes()); +} + +/// Deserialize a Postgres lquery string +#[inline] +pub fn lquery_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + match buf { + // Remove the version number from the front of the lquery per spec + [1u8, rest @ ..] => Ok(str::from_utf8(rest)?), + _ => Err("lquery version 1 only supported".into()), + } +} + +/// Serializes a Postgres ltxtquery string +#[inline] +pub fn ltxtquery_to_sql(v: &str, buf: &mut BytesMut) { + // A version number is prepended to an ltxtquery string per spec + buf.put_u8(1); + // Append the rest of the query + buf.put_slice(v.as_bytes()); +} + +/// Deserialize a Postgres ltxtquery string +#[inline] +pub fn ltxtquery_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + match buf { + // Remove the version number from the front of the ltxtquery per spec + [1u8, rest @ ..] => Ok(str::from_utf8(rest)?), + _ => Err("ltxtquery version 1 only supported".into()), + } +} diff --git a/postgres-protocol/src/types/test.rs b/postgres-protocol/src/types/test.rs index 7c20cf3ed..6f1851fc2 100644 --- a/postgres-protocol/src/types/test.rs +++ b/postgres-protocol/src/types/test.rs @@ -1,4 +1,4 @@ -use bytes::BytesMut; +use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; use std::collections::HashMap; @@ -156,3 +156,87 @@ fn non_null_array() { assert_eq!(array.dimensions().collect::>().unwrap(), dimensions); assert_eq!(array.values().collect::>().unwrap(), values); } + +#[test] +fn ltree_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + let mut buf = BytesMut::new(); + + ltree_to_sql("A.B.C", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn ltree_str() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_))) +} + +#[test] +fn ltree_wrong_version() { + let mut query = vec![2u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(matches!(ltree_from_sql(query.as_slice()), Err(_))) +} + +#[test] +fn lquery_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + let mut buf = BytesMut::new(); + + lquery_to_sql("A.B.C", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn lquery_str() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(matches!(lquery_from_sql(query.as_slice()), Ok(_))) +} + +#[test] +fn lquery_wrong_version() { + let mut query = vec![2u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(matches!(lquery_from_sql(query.as_slice()), Err(_))) +} + +#[test] +fn ltxtquery_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("a & b*".as_bytes()); + + let mut buf = BytesMut::new(); + + ltree_to_sql("a & b*", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn ltxtquery_str() { + let mut query = vec![1u8]; + query.extend_from_slice("a & b*".as_bytes()); + + assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_))) +} + +#[test] +fn ltxtquery_wrong_version() { + let mut query = vec![2u8]; + query.extend_from_slice("a & b*".as_bytes()); + + assert!(matches!(ltree_from_sql(query.as_slice()), Err(_))) +} diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml index 7eca3fbcf..000d71ea0 100644 --- a/postgres-types/Cargo.toml +++ b/postgres-types/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres-types" -version = "0.2.2" +version = "0.2.3" authors = ["Steven Fackler "] edition = "2018" license = "MIT/Apache-2.0" @@ -14,6 +14,7 @@ categories = ["database"] derive = ["postgres-derive"] array-impls = ["array-init"] with-bit-vec-0_6 = ["bit-vec-06"] +with-cidr-0_2 = ["cidr-02"] with-chrono-0_4 = ["chrono-04"] with-eui48-0_4 = ["eui48-04"] with-eui48-1 = ["eui48-1"] @@ -27,12 +28,13 @@ with-time-0_3 = ["time-03"] [dependencies] bytes = "1.0" fallible-iterator = "0.2" -postgres-protocol = { version = "0.6.1", path = "../postgres-protocol" } +postgres-protocol = { version = "0.6.4", path = "../postgres-protocol" } postgres-derive = { version = "0.4.0", optional = true, path = "../postgres-derive" } array-init = { version = "2", optional = true } bit-vec-06 = { version = "0.6", package = "bit-vec", optional = true } chrono-04 = { version = "0.4.16", package = "chrono", default-features = false, features = ["clock"], optional = true } +cidr-02 = { version = "0.2", package = "cidr", optional = true } eui48-04 = { version = "0.4", package = "eui48", optional = true } eui48-1 = { version = "1.0", package = "eui48", optional = true } geo-types-06 = { version = "0.6", package = "geo-types", optional = true } diff --git a/postgres-types/src/cidr_02.rs b/postgres-types/src/cidr_02.rs new file mode 100644 index 000000000..2de952c3c --- /dev/null +++ b/postgres-types/src/cidr_02.rs @@ -0,0 +1,44 @@ +use bytes::BytesMut; +use cidr_02::{IpCidr, IpInet}; +use postgres_protocol::types; +use std::error::Error; + +use crate::{FromSql, IsNull, ToSql, Type}; + +impl<'a> FromSql<'a> for IpCidr { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let inet = types::inet_from_sql(raw)?; + Ok(IpCidr::new(inet.addr(), inet.netmask())?) + } + + accepts!(CIDR); +} + +impl ToSql for IpCidr { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + types::inet_to_sql(self.first_address(), self.network_length(), w); + Ok(IsNull::No) + } + + accepts!(CIDR); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for IpInet { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let inet = types::inet_from_sql(raw)?; + Ok(IpInet::new(inet.addr(), inet.netmask())?) + } + + accepts!(INET); +} + +impl ToSql for IpInet { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + types::inet_to_sql(self.address(), self.network_length(), w); + Ok(IsNull::No) + } + + accepts!(INET); + to_sql_checked!(); +} diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index 0247b90b7..d029d3948 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -212,6 +212,8 @@ where mod bit_vec_06; #[cfg(feature = "with-chrono-0_4")] mod chrono_04; +#[cfg(feature = "with-cidr-0_2")] +mod cidr_02; #[cfg(feature = "with-eui48-0_4")] mod eui48_04; #[cfg(feature = "with-eui48-1")] @@ -405,6 +407,7 @@ impl WrongType { /// | `f32` | REAL | /// | `f64` | DOUBLE PRECISION | /// | `&str`/`String` | VARCHAR, CHAR(n), TEXT, CITEXT, NAME, UNKNOWN | +/// | | LTREE, LQUERY, LTXTQUERY | /// | `&[u8]`/`Vec` | BYTEA | /// | `HashMap>` | HSTORE | /// | `SystemTime` | TIMESTAMP, TIMESTAMP WITH TIME ZONE | @@ -436,6 +439,8 @@ impl WrongType { /// | `uuid::Uuid` | UUID | /// | `bit_vec::BitVec` | BIT, VARBIT | /// | `eui48::MacAddress` | MACADDR | +/// | `cidr::InetCidr` | CIDR | +/// | `cidr::InetAddr` | INET | /// /// # Nullability /// @@ -590,8 +595,8 @@ impl<'a> FromSql<'a> for &'a [u8] { } impl<'a> FromSql<'a> for String { - fn from_sql(_: &Type, raw: &'a [u8]) -> Result> { - types::text_from_sql(raw).map(ToString::to_string) + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result> { + <&str as FromSql>::from_sql(ty, raw).map(ToString::to_string) } fn accepts(ty: &Type) -> bool { @@ -600,8 +605,8 @@ impl<'a> FromSql<'a> for String { } impl<'a> FromSql<'a> for Box { - fn from_sql(_: &Type, raw: &'a [u8]) -> Result, Box> { - types::text_from_sql(raw) + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result, Box> { + <&str as FromSql>::from_sql(ty, raw) .map(ToString::to_string) .map(String::into_boxed_str) } @@ -612,14 +617,26 @@ impl<'a> FromSql<'a> for Box { } impl<'a> FromSql<'a> for &'a str { - fn from_sql(_: &Type, raw: &'a [u8]) -> Result<&'a str, Box> { - types::text_from_sql(raw) + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<&'a str, Box> { + match *ty { + ref ty if ty.name() == "ltree" => types::ltree_from_sql(raw), + ref ty if ty.name() == "lquery" => types::lquery_from_sql(raw), + ref ty if ty.name() == "ltxtquery" => types::ltxtquery_from_sql(raw), + _ => types::text_from_sql(raw), + } } fn accepts(ty: &Type) -> bool { match *ty { Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, - ref ty if ty.name() == "citext" => true, + ref ty + if (ty.name() == "citext" + || ty.name() == "ltree" + || ty.name() == "lquery" + || ty.name() == "ltxtquery") => + { + true + } _ => false, } } @@ -723,6 +740,7 @@ pub enum IsNull { /// | `f32` | REAL | /// | `f64` | DOUBLE PRECISION | /// | `&str`/`String` | VARCHAR, CHAR(n), TEXT, CITEXT, NAME | +/// | | LTREE, LQUERY, LTXTQUERY | /// | `&[u8]`/`Vec` | BYTEA | /// | `HashMap>` | HSTORE | /// | `SystemTime` | TIMESTAMP, TIMESTAMP WITH TIME ZONE | @@ -920,15 +938,27 @@ impl ToSql for Vec { } impl<'a> ToSql for &'a str { - fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { - types::text_to_sql(*self, w); + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + match *ty { + ref ty if ty.name() == "ltree" => types::ltree_to_sql(*self, w), + ref ty if ty.name() == "lquery" => types::lquery_to_sql(*self, w), + ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(*self, w), + _ => types::text_to_sql(*self, w), + } Ok(IsNull::No) } fn accepts(ty: &Type) -> bool { match *ty { Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, - ref ty if ty.name() == "citext" => true, + ref ty + if (ty.name() == "citext" + || ty.name() == "ltree" + || ty.name() == "lquery" + || ty.name() == "ltxtquery") => + { + true + } _ => false, } } diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 94371af51..82e71fb1c 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-postgres" -version = "0.7.5" +version = "0.7.6" authors = ["Steven Fackler "] edition = "2018" license = "MIT/Apache-2.0" @@ -50,8 +50,8 @@ parking_lot = "0.12" percent-encoding = "2.0" pin-project-lite = "0.2" phf = "0.10" -postgres-protocol = { version = "0.6.1", path = "../postgres-protocol" } -postgres-types = { version = "0.2.2", path = "../postgres-types" } +postgres-protocol = { version = "0.6.4", path = "../postgres-protocol" } +postgres-types = { version = "0.2.3", path = "../postgres-types" } socket2 = "0.4" tokio = { version = "1.0", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } diff --git a/tokio-postgres/tests/test/types/mod.rs b/tokio-postgres/tests/test/types/mod.rs index 604e2de32..de700d791 100644 --- a/tokio-postgres/tests/test/types/mod.rs +++ b/tokio-postgres/tests/test/types/mod.rs @@ -648,3 +648,90 @@ async fn inet() { ) .await; } + +#[tokio::test] +async fn ltree() { + test_type( + "ltree", + &[(Some("b.c.d".to_owned()), "'b.c.d'"), (None, "NULL")], + ) + .await; +} + +#[tokio::test] +async fn ltree_any() { + test_type( + "ltree[]", + &[ + (Some(vec![]), "ARRAY[]"), + (Some(vec!["a.b.c".to_string()]), "ARRAY['a.b.c']"), + ( + Some(vec!["a.b.c".to_string(), "e.f.g".to_string()]), + "ARRAY['a.b.c','e.f.g']", + ), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn lquery() { + test_type( + "lquery", + &[ + (Some("b.c.d".to_owned()), "'b.c.d'"), + (Some("b.c.*".to_owned()), "'b.c.*'"), + (Some("b.*{1,2}.d|e".to_owned()), "'b.*{1,2}.d|e'"), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn lquery_any() { + test_type( + "lquery[]", + &[ + (Some(vec![]), "ARRAY[]"), + (Some(vec!["b.c.*".to_string()]), "ARRAY['b.c.*']"), + ( + Some(vec!["b.c.*".to_string(), "b.*{1,2}.d|e".to_string()]), + "ARRAY['b.c.*','b.*{1,2}.d|e']", + ), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn ltxtquery() { + test_type( + "ltxtquery", + &[ + (Some("b & c & d".to_owned()), "'b & c & d'"), + (Some("b@* & !c".to_owned()), "'b@* & !c'"), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn ltxtquery_any() { + test_type( + "ltxtquery[]", + &[ + (Some(vec![]), "ARRAY[]"), + (Some(vec!["b & c & d".to_string()]), "ARRAY['b & c & d']"), + ( + Some(vec!["b & c & d".to_string(), "b@* & !c".to_string()]), + "ARRAY['b & c & d','b@* & !c']", + ), + (None, "NULL"), + ], + ) + .await; +}