Skip to content

Commit

Permalink
Add from/to_tag_and_arg (tweag#1939)
Browse files Browse the repository at this point in the history
* Add from/to_tag_and_arg

This commit adds two functions to convert enums to and back from
records, as a tag and an optional argument. Such functions are useful to
handle enums in a general, dynamic way, while pattern matching requires
to know in advance the possible tags.

Additionally, we also implement a `map` function, which can be derived
from the conversions above.

* Update core/stdlib/std.ncl

Co-authored-by: jneem <[email protected]>

* Update core/stdlib/std.ncl

Co-authored-by: jneem <[email protected]>

* Update core/stdlib/std.ncl

Co-authored-by: jneem <[email protected]>

---------

Co-authored-by: jneem <[email protected]>
  • Loading branch information
yannham and jneem authored Jun 10, 2024
1 parent 9fe8b98 commit 95967eb
Show file tree
Hide file tree
Showing 9 changed files with 186 additions and 14 deletions.
23 changes: 22 additions & 1 deletion core/src/eval/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//! On the other hand, the functions `process_unary_operation` and `process_binary_operation`
//! receive evaluated operands and implement the actual semantics of operators.
use super::{
cache::lazy::Thunk,
merge::{self, MergeMode},
stack::StrAccData,
subst, Cache, Closure, Environment, ImportResolver, VirtualMachine,
Expand Down Expand Up @@ -1216,7 +1217,7 @@ impl<R: ImportResolver, C: Cache> VirtualMachine<R, C> {
))
}
}
UnaryOp::EnumUnwrapVariant => {
UnaryOp::EnumGetArg => {
if let Term::EnumVariant { arg, .. } = &*t {
Ok(Closure {
body: arg.clone(),
Expand All @@ -1226,6 +1227,26 @@ impl<R: ImportResolver, C: Cache> VirtualMachine<R, C> {
Err(mk_type_error!("enum_unwrap_variant", "Enum variant"))
}
}
UnaryOp::EnumMakeVariant => {
let Term::Str(tag) = &*t else {
return Err(mk_type_error!("enum/make_variant", "String"));
};

let (arg_clos, _) = self.stack.pop_arg(&self.cache).ok_or_else(|| {
EvalError::NotEnoughArgs(2, String::from("enum/make_variant"), pos)
})?;
let arg_pos = arg_clos.body.pos;
let arg = RichTerm::new(Term::Closure(Thunk::new(arg_clos)), arg_pos);

Ok(Closure::atomic_closure(RichTerm::new(
Term::EnumVariant {
tag: LocIdent::new(tag).with_pos(pos),
arg,
attrs: EnumVariantAttrs { closurized: true },
},
pos_op_inh,
)))
}
UnaryOp::EnumGetTag => match &*t {
Term::EnumVariant { tag, .. } | Term::Enum(tag) => Ok(Closure::atomic_closure(
RichTerm::new(Term::Enum(*tag), pos_op_inh),
Expand Down
6 changes: 4 additions & 2 deletions core/src/parser/grammar.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,8 @@ UOp: UnaryOp = {
})
}
},
"enum/unwrap_variant" => UnaryOp::EnumUnwrapVariant,
"enum/get_arg" => UnaryOp::EnumGetArg,
"enum/make_variant" => UnaryOp::EnumMakeVariant,
"enum/is_variant" => UnaryOp::EnumIsVariant,
"enum/get_tag" => UnaryOp::EnumGetTag,
}
Expand Down Expand Up @@ -1585,7 +1586,8 @@ extern {
"label/push_diag" => Token::Normal(NormalToken::LabelPushDiag),
"array/slice" => Token::Normal(NormalToken::ArraySlice),
"eval_nix" => Token::Normal(NormalToken::EvalNix),
"enum/unwrap_variant" => Token::Normal(NormalToken::EnumUnwrapVariant),
"enum/get_arg" => Token::Normal(NormalToken::EnumGetArg),
"enum/make_variant" => Token::Normal(NormalToken::EnumMakeVariant),
"enum/is_variant" => Token::Normal(NormalToken::EnumIsVariant),
"enum/get_tag" => Token::Normal(NormalToken::EnumGetTag),
"pattern_branch" => Token::Normal(NormalToken::PatternBranch),
Expand Down
6 changes: 4 additions & 2 deletions core/src/parser/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,10 @@ pub enum NormalToken<'input> {
NumberFromString,
#[token("%enum/from_string%")]
EnumFromString,
#[token("%enum/unwrap_variant%")]
EnumUnwrapVariant,
#[token("%enum/get_arg%")]
EnumGetArg,
#[token("%enum/make_variant%")]
EnumMakeVariant,
#[token("%enum/is_variant%")]
EnumIsVariant,
#[token("%enum/get_tag%")]
Expand Down
10 changes: 7 additions & 3 deletions core/src/term/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1415,8 +1415,11 @@ pub enum UnaryOp {
#[cfg(feature = "nix-experimental")]
EvalNix,

/// Unwrap the variant from an enum: `%unwrap_enum_variant% ('Foo t) := t`
EnumUnwrapVariant,
/// Retrive the argument from an enum variant: `%enum/get_arg% ('Foo t) := t`
EnumGetArg,
/// Create an enum variant from a tag and an argument. This operator is strict in tag and
/// return a function that can be further applied to an argument.
EnumMakeVariant,
/// Return true if the given parameter is an enum variant.
EnumIsVariant,
/// Extract the tag from an enum tag or an enum variant.
Expand Down Expand Up @@ -1490,7 +1493,8 @@ impl fmt::Display for UnaryOp {
#[cfg(feature = "nix-experimental")]
EvalNix => write!(f, "eval_nix"),

EnumUnwrapVariant => write!(f, "enum/unwrap_variant"),
EnumGetArg => write!(f, "enum/get_arg"),
EnumMakeVariant => write!(f, "enum/make_variant"),
EnumIsVariant => write!(f, "enum/is_variant"),
EnumGetTag => write!(f, "enum/get_tag"),

Expand Down
2 changes: 1 addition & 1 deletion core/src/term/pattern/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ impl CompilePart for EnumPattern {
if_condition,
make::let_in(
value_id,
make::op1(UnaryOp::EnumUnwrapVariant, Term::Var(value_id)),
make::op1(UnaryOp::EnumGetArg, Term::Var(value_id)),
pat.compile_part(value_id, bindings_id),
),
Term::Null,
Expand Down
11 changes: 9 additions & 2 deletions core/src/typecheck/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,15 @@ pub fn get_uop_type(
// primop.
// This isn't a problem, as this operator is mostly internal and pattern matching should be
// used to destructure enum variants.
UnaryOp::EnumUnwrapVariant => (mk_uniftype::dynamic(), mk_uniftype::dynamic()),
// Same as `EnumUnwrapVariant` just above.
// Dyn -> Dyn
UnaryOp::EnumGetArg => (mk_uniftype::dynamic(), mk_uniftype::dynamic()),
// String -> (Dyn -> Dyn)
UnaryOp::EnumMakeVariant => (
mk_uniftype::str(),
mk_uniftype::arrow(mk_uniftype::dynamic(), mk_uniftype::dynamic()),
),
// Same as `EnumGetArg` just above.
// Dyn -> Dyn
UnaryOp::EnumGetTag => (mk_uniftype::dynamic(), mk_uniftype::dynamic()),
// Note that is_variant breaks parametricity, so it can't get a polymorphic type.
// Dyn -> Bool
Expand Down
96 changes: 94 additions & 2 deletions core/stdlib/std.ncl
Original file line number Diff line number Diff line change
Expand Up @@ -1566,15 +1566,31 @@

```nickel
('foo | std.enum.Tag) =>
`foo
'foo
('FooBar | std.enum.Tag) =>
`FooBar
'FooBar
("tag" | std.enum.Tag) =>
error
```
"%
= std.contract.from_predicate is_enum_tag,

Enum
| doc m%"
Enforces that the value is an enum (either a tag or a variant).

# Examples

```nickel
('Foo | std.enum.Enum) =>
'Foo
('Bar 5 | std.enum.Enum) =>
'Bar 5
("tag" | std.enum.Enum) =>
error
"%
= std.contract.from_predicate std.is_enum,

TagOrString
| doc m%%"
Accepts both enum tags and strings. Strings are automatically
Expand Down Expand Up @@ -1670,6 +1686,82 @@
```
"%
= fun value => %enum/is_variant% value,

to_tag_and_arg
| Enum -> { tag | String, arg | optional }
| doc m%"
Convert an enum to record with a string tag and an optional argument. If
the enum is an enum tag, the `arg` field is simply omitted.

`std.enum.from_tag_and_arg` provides the inverse transformation,
reconstructing an enum from a string tag and an argument.

# Examples

```nickel
std.enum.to_tag_and_arg ('Foo "arg") =>
{ tag = "Foo", arg = "arg" }
std.enum.to_tag_and_arg 'http
=> { tag = "http" }
```
"%
= fun enum_value =>
let tag_string = %to_string% (%enum/get_tag% enum_value) in
if %enum/is_variant% enum_value then
{
tag = tag_string,
arg = %enum/get_arg% enum_value,
}
else
{ tag = tag_string },

from_tag_and_arg
| { tag | String, arg | optional } -> Enum
| doc m%"
Create an enum from a string tag and an optional argument. If the `arg`
field is omitted, a bare enum tag is created.

`std.enum.to_tag_and_value` provides the inverse transformation,
extracting a string tag and an argument from an enum.

# Examples

```nickel
std.enum.from_tag_and_arg { tag = "Foo", arg = "arg" }
=> ('Foo "arg")
std.enum.from_tag_and_arg { tag = "http" }
=> 'http
```
"%
= fun enum_data =>
if %record/has_field% "arg" enum_data then
%enum/make_variant% enum_data.tag enum_data.arg
else
%enum/from_string% enum_data.tag,

map
| (Dyn -> Dyn) -> Enum -> Enum
| doc m%"
Maps a function over an enum variant's argument. If the enum doesn't
have an argument, it is left unchanged.

# Examples

```nickel
std.enum.map ((+) 1) ('Foo 42)
=> 'Foo 43
std.enum.map f 'Bar
=> 'Bar
```
"%
= fun f enum_value =>
if %enum/is_variant% enum_value then
let tag = (%to_string% (%enum/get_tag% enum_value)) in
let mapped = f (%enum/get_arg% enum_value) in

%enum/make_variant% tag mapped
else
enum_value,
},

function = {
Expand Down
2 changes: 1 addition & 1 deletion core/tests/integration/inputs/adts/enum_primops.ncl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


[
%enum/unwrap_variant% ('Left (1+1)) == 2,
%enum/get_arg% ('Left (1+1)) == 2,
!(%enum/is_variant% 'Right),
%enum/is_variant% ('Right 1),
%enum/get_tag% 'Right == 'Right,
Expand Down
44 changes: 44 additions & 0 deletions core/tests/integration/inputs/stdlib/enum.ncl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# test.type = 'pass'
let enum = std.enum in

[
enum.is_enum_tag 'A,
!(enum.is_enum_tag ('A 'arg)),
enum.is_enum_variant ('A 'arg),
!enum.is_enum_variant 'A,

let enum_round_trip = fun enum_value =>
enum_value
|> enum.to_tag_and_arg
|> enum.from_tag_and_arg
|> (==) enum_value
in
[
enum_round_trip 'Foo,
enum_round_trip ('Foo 'arg),
enum_round_trip ('Foo { value = "hello" }),
]
|> std.test.assert_all,

let record_round_trip = fun data =>
data
|> enum.from_tag_and_arg
|> enum.to_tag_and_arg
|> (==) data
in
[
record_round_trip { tag = "Foo" },
record_round_trip { tag = "Foo", arg = "arg" },
record_round_trip { tag = "Foo", arg = { value = "hello" } },
]
|> std.test.assert_all,

'Foo
|> std.enum.map (fun _ => null)
|> (==) 'Foo,

'Foo 2
|> std.enum.map ((*) 2)
|> (==) ('Foo 4),
]
|> std.test.assert_all

0 comments on commit 95967eb

Please sign in to comment.