Skip to content

Commit

Permalink
Statically enforce primitive string size invariant.
Browse files Browse the repository at this point in the history
Otherwise it is too easy to produce a string violating the size
constraints (e.g., using Ltac2), and prove False.

This also adds an Ltac2 API to the Pstring module.
  • Loading branch information
rlepigre committed Jun 13, 2024
1 parent 35c23c9 commit a5d9d6c
Show file tree
Hide file tree
Showing 47 changed files with 158 additions and 71 deletions.
4 changes: 2 additions & 2 deletions dev/top_printers.ml
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ let constr_display csr =
| Float f ->
"Float("^(Float64.to_string f)^")"
| String s ->
Printf.sprintf "String(%S)" s
Printf.sprintf "String(%S)" (Pstring.to_string s)
| Array (u,t,def,ty) -> "Array("^(array_display t)^","^(term_display def)^","^(term_display ty)^")@{" ^universes_display u^"\n"

and array_display v =
Expand Down Expand Up @@ -530,7 +530,7 @@ let print_pure_constr csr =
| Float f ->
print_string ("Float("^(Float64.to_string f)^")")
| String s ->
print_string (Printf.sprintf "String(%S)" s)
print_string (Printf.sprintf "String(%S)" (Pstring.to_string s))
| Array (u,t,def,ty) ->
print_string "Array(";
Array.iter (fun x -> box_display x; print_space()) t;
Expand Down
2 changes: 1 addition & 1 deletion dev/vm_printers.ml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ and ppwhd whd =
| Vblock b -> ppvblock b
| Vint64 i -> printf "int64(%LiL)" i
| Vfloat64 f -> printf "float64(%.17g)" f
| Vstring s -> printf "string(%S)" s
| Vstring s -> printf "string(%S)" (Pstring.to_string s)
| Varray t -> ppvarray t
| Vaccu (a, s) ->
open_hbox();ppatom a;close_box();
Expand Down
1 change: 1 addition & 0 deletions doc/stdlib/index-list.html.template
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,7 @@ through the <tt>Require Import</tt> command.</p>
user-contrib/Ltac2/Pattern.v
user-contrib/Ltac2/Printf.v
user-contrib/Ltac2/Proj.v
user-contrib/Ltac2/Pstring.v
user-contrib/Ltac2/RedFlags.v
user-contrib/Ltac2/Ref.v
user-contrib/Ltac2/Std.v
Expand Down
2 changes: 1 addition & 1 deletion engine/eConstr.mli
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ val mkArrow : t -> ERelevance.t -> t -> t
val mkArrowR : t -> t -> t
val mkInt : Uint63.t -> t
val mkFloat : Float64.t -> t
val mkString : String.t -> t
val mkString : Pstring.t -> t
val mkArray : EInstance.t * t array * t * t -> t

module UnsafeMonomorphic : sig
Expand Down
2 changes: 1 addition & 1 deletion interp/constrextern.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,7 @@ let rec extern inctx scopes vars r =

| GString s ->
extern_prim_token_delimiter_if_required
(String s)
(String (Pstring.to_string s))
"pstring" "pstring_scope" (snd scopes)

| GArray(u,t,def,ty) ->
Expand Down
9 changes: 4 additions & 5 deletions interp/notation.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1133,10 +1133,9 @@ let coqbyte_of_string ?loc esig byte s =
let coqbyte_of_char esig byte c = coqbyte_of_char_code esig byte (Char.code c)

let pstring_of_string ?loc s =
if String.length s > Pstring.max_length_int then
user_err ?loc (str "String literal would be too large on a 32-bits system.")
else
Constr.mkString s
match Pstring.of_string s with
| Some s -> Constr.mkString s
| None -> user_err ?loc (str "String literal would be too large on a 32-bits system.")

let make_ascii_string n =
if n>=32 && n<=126 then String.make 1 (char_of_int n)
Expand All @@ -1150,7 +1149,7 @@ let string_of_coqbyte c = make_ascii_string (char_code_of_coqbyte c)

let string_of_pstring c =
match TokenValue.kind c with
| TString s -> s
| TString s -> Pstring.to_string s
| _ -> raise NotAValidPrimToken

let coqlist_byte_of_string esig byte_ty list_ty str =
Expand Down
2 changes: 1 addition & 1 deletion interp/notation_ops.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1575,7 +1575,7 @@ let rec match_ inner u alp metas sigma a1 a2 =
| GSort s1, NSort s2 when glob_sort_eq s1 s2 -> sigma
| GInt i1, NInt i2 when Uint63.equal i1 i2 -> sigma
| GFloat f1, NFloat f2 when Float64.equal f1 f2 -> sigma
| GString s1, NString s2 when String.equal s1 s2 -> sigma
| GString s1, NString s2 when Pstring.equal s1 s2 -> sigma
| GPatVar _, NHole _ -> (*Don't hide Metas, they bind in ltac*) raise No_match
| a, NHole _ -> sigma

Expand Down
2 changes: 1 addition & 1 deletion kernel/cClosure.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1745,7 +1745,7 @@ and match_head : 'a. ('a, 'a patstate) reduction -> _ -> _ -> pat_state:(fconstr
| FString s' ->
let elims, states = extract_or_kill2 (function [@ocaml.warning "-4"]
| (PHString s, elims), psubst ->
if not @@ String.equal s s' then None else
if not @@ Pstring.equal s s' then None else
Some (elims, psubst)
| _ -> None) patterns states
in
Expand Down
10 changes: 5 additions & 5 deletions kernel/constr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ type ('constr, 'types, 'sort, 'univs, 'r) kind_of_term =
| Proj of Projection.t * 'r * 'constr
| Int of Uint63.t
| Float of Float64.t
| String of String.t
| String of Pstring.t
| Array of 'univs * 'constr array * 'constr * 'types

(* constr is the fixpoint of the previous type. *)
Expand Down Expand Up @@ -907,7 +907,7 @@ let compare_head_gen_leq_with kind1 kind2 leq_universes leq_sorts eq_evars eq le
| Var id1, Var id2 -> Id.equal id1 id2
| Int i1, Int i2 -> Uint63.equal i1 i2
| Float f1, Float f2 -> Float64.equal f1 f2
| String s1, String s2 -> String.equal s1 s2
| String s1, String s2 -> Pstring.equal s1 s2
| Sort s1, Sort s2 -> leq_sorts s1 s2
| Prod (_,t1,c1), Prod (_,t2,c2) -> eq 0 t1 t2 && leq 0 c1 c2
| Lambda (_,t1,c1), Lambda (_,t2,c2) -> eq 0 t1 t2 && eq 0 c1 c2
Expand Down Expand Up @@ -1079,7 +1079,7 @@ let constr_ord_int f t1 t2 =
| Int _, _ -> -1 | _, Int _ -> 1
| Float f1, Float f2 -> Float64.total_compare f1 f2
| Float _, _ -> -1 | _, Float _ -> 1
| String s1, String s2 -> String.compare s1 s2
| String s1, String s2 -> Pstring.compare s1 s2
| String _, _ -> -1 | _, String _ -> 1
| Array(_u1,t1,def1,ty1), Array(_u2,t2,def2,ty2) ->
compare [(Array.compare f, t1, t2); (f, def1, def2); (f, ty1, ty2)]
Expand Down Expand Up @@ -1182,7 +1182,7 @@ let hasheq t1 t2 =
&& array_eqeq bl1 bl2
| Int i1, Int i2 -> i1 == i2
| Float f1, Float f2 -> Float64.equal f1 f2
| String s1, String s2 -> String.equal s1 s2
| String s1, String s2 -> Pstring.equal s1 s2
| Array(u1,t1,def1,ty1), Array(u2,t2,def2,ty2) ->
u1 == u2 && def1 == def2 && ty1 == ty2 && array_eqeq t1 t2
| (Rel _ | Meta _ | Var _ | Sort _ | Cast _ | Prod _ | Lambda _ | LetIn _
Expand Down Expand Up @@ -1543,7 +1543,7 @@ let rec debug_print c =
str"}")
| Int i -> str"Int("++str (Uint63.to_string i) ++ str")"
| Float i -> str"Float("++str (Float64.to_string i) ++ str")"
| String s -> str"String("++str (Printf.sprintf "%S" s) ++ str")"
| String s -> str"String("++str (Printf.sprintf "%S" (Pstring.to_string s)) ++ str")"
| Array(u,t,def,ty) -> str"Array(" ++ prlist_with_sep pr_comma debug_print (Array.to_list t) ++ str" | "
++ debug_print def ++ str " : " ++ debug_print ty
++ str")@{" ++ UVars.Instance.pr Sorts.QVar.raw_pr Univ.Level.raw_pr u ++ str"}"
Expand Down
4 changes: 2 additions & 2 deletions kernel/constr.mli
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ val mkArray : UVars.Instance.t * constr array * constr * types -> constr
val mkFloat : Float64.t -> constr

(** Constructs a machine string. *)
val mkString : string -> constr
val mkString : Pstring.t -> constr

(** Constructs an patvar named "?n" *)
val mkMeta : metavariable -> constr
Expand Down Expand Up @@ -287,7 +287,7 @@ type ('constr, 'types, 'sort, 'univs, 'r) kind_of_term =
(** The relevance is the relevance of the whole term *)
| Int of Uint63.t
| Float of Float64.t
| String of String.t
| String of Pstring.t
| Array of 'univs * 'constr array * 'constr * 'types
(** [Array (u,vals,def,t)] is an array of [vals] in type [t] with default value [def].
[u] is a universe containing [t]. *)
Expand Down
4 changes: 2 additions & 2 deletions kernel/conversion.ml
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ let rec compare_under e1 c1 e2 c2 =
| Var id1, Var id2 -> Id.equal id1 id2
| Int i1, Int i2 -> Uint63.equal i1 i2
| Float f1, Float f2 -> Float64.equal f1 f2
| String s1, String s2 -> String.equal s1 s2
| String s1, String s2 -> Pstring.equal s1 s2
| Sort s1, Sort s2 ->
let subst_instance_sort u s =
if UVars.Instance.is_empty u then s else UVars.subst_instance_sort u s
Expand Down Expand Up @@ -706,7 +706,7 @@ and eqappr cv_pb l2r infos (lft1,st1) (lft2,st2) cuniv =
else raise NotConvertible

| FString s1, FString s2 ->
if String.equal s1 s2 then convert_stacks l2r infos lft1 lft2 v1 v2 cuniv
if Pstring.equal s1 s2 then convert_stacks l2r infos lft1 lft2 v1 v2 cuniv
else raise NotConvertible

| FCaseInvert (ci1,u1,pms1,p1,iv1,_,br1,e1), FCaseInvert (ci2,u2,pms2,p2,iv2,_,br2,e2) ->
Expand Down
4 changes: 2 additions & 2 deletions kernel/genlambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type 'v lambda =
(* inductive name, constructor tag, arguments *)
| Luint of Uint63.t
| Lfloat of Float64.t
| Lstring of String.t
| Lstring of Pstring.t
| Lval of 'v
| Lsort of Sorts.t
| Lind of pinductive
Expand Down Expand Up @@ -162,7 +162,7 @@ let rec pp_lam lam =
str")")
| Luint i -> str (Uint63.to_string i)
| Lfloat f -> str (Float64.to_string f)
| Lstring s -> str (Printf.sprintf "%S" s)
| Lstring s -> str (Printf.sprintf "%S" (Pstring.to_string s))
| Lval _ -> str "values"
| Lsort s -> pp_sort s
| Lind ((mind,i), _) -> MutInd.print mind ++ str"#" ++ int i
Expand Down
2 changes: 1 addition & 1 deletion kernel/genlambda.mli
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type 'v lambda =
(* inductive name, constructor tag, arguments *)
| Luint of Uint63.t
| Lfloat of Float64.t
| Lstring of String.t
| Lstring of Pstring.t
| Lval of 'v
| Lsort of Sorts.t
| Lind of pinductive
Expand Down
6 changes: 3 additions & 3 deletions kernel/nativecode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ type mllambda =
| MLint of int
| MLuint of Uint63.t
| MLfloat of Float64.t
| MLstring of String.t
| MLstring of Pstring.t
| MLsetref of string * mllambda
| MLsequence of mllambda * mllambda
| MLarray of mllambda array
Expand Down Expand Up @@ -570,7 +570,7 @@ let rec eq_mllambda gn1 gn2 n env1 env2 t1 t2 =
| MLfloat f1, MLfloat f2 ->
Float64.equal f1 f2
| MLstring s1, MLstring s2 ->
String.equal s1 s2
Pstring.equal s1 s2
| MLsetref (id1, ml1), MLsetref (id2, ml2) ->
String.equal id1 id2 &&
eq_mllambda gn1 gn2 n env1 env2 ml1 ml2
Expand Down Expand Up @@ -1832,7 +1832,7 @@ let pp_mllam fmt l =
| MLint i -> pp_int fmt i
| MLuint i -> Format.fprintf fmt "(%s)" (Uint63.compile i)
| MLfloat f -> Format.fprintf fmt "(%s)" (Float64.compile f)
| MLstring s -> Format.fprintf fmt "%S" s
| MLstring s -> Format.fprintf fmt "(%s)" (Pstring.compile s)
| MLsetref (s, body) ->
Format.fprintf fmt "@[%s@ :=@\n %a@]" s pp_mllam body
| MLsequence(l1,l2) ->
Expand Down
2 changes: 1 addition & 1 deletion kernel/nativeconv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ let rec conv_val env pb lvl v1 v2 cu =
if Float64.(equal (of_float f1) (of_float f2)) then cu
else raise NotConvertible
| Vstring s1, Vstring s2 ->
if String.equal s1 s2 then cu
if Pstring.equal s1 s2 then cu
else raise NotConvertible
| Varray t1, Varray t2 ->
let len = Parray.length_int t1 in
Expand Down
2 changes: 1 addition & 1 deletion kernel/nativevalues.mli
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ val mk_uint : Uint63.t -> t

val mk_float : Float64.t -> t

val mk_string : string -> t
val mk_string : Pstring.t -> t

val napply : t -> t array -> t
(* Functions over accumulators *)
Expand Down
4 changes: 2 additions & 2 deletions kernel/primred.mli
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ module type RedNativeEntries =
val get : args -> int -> elem
val get_int : evd -> elem -> Uint63.t
val get_float : evd -> elem -> Float64.t
val get_string : evd -> elem -> String.t
val get_string : evd -> elem -> Pstring.t
val get_parray : evd -> elem -> elem Parray.t
val mkInt : env -> Uint63.t -> elem
val mkFloat : env -> Float64.t -> elem
val mkString : env -> String.t -> elem
val mkString : env -> Pstring.t -> elem
val mkBool : env -> bool -> elem
val mkCarry : env -> bool -> elem -> elem (* true if carry *)
val mkIntPair : env -> elem -> elem -> elem
Expand Down
12 changes: 11 additions & 1 deletion kernel/pstring.ml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ let max_length_int : int = 16777211

let max_length : Uint63.t = Uint63.of_int max_length_int

let to_string : t -> string = fun s -> s

let of_string : string -> t option = fun s ->
if String.length s <= max_length_int then Some s else None

(* Return a string of size [max_length] if the parameter is too large.
Use [c land 255] if [c] is not a valid character. *)
let make : Uint63.t -> char63 -> t = fun i c ->
Expand All @@ -29,7 +34,7 @@ let length : t -> Uint63.t = fun s ->
Uint63.of_int (String.length s)

(* Out of bound access gives '\x00'. *)
let get : string -> Uint63.t -> char63 = fun s i ->
let get : t -> Uint63.t -> char63 = fun s i ->
let i = Uint63.to_int_min i max_length_int in
if i < String.length s then
Uint63.of_int (Char.code (String.get s i))
Expand Down Expand Up @@ -66,3 +71,8 @@ let equal : t -> t -> bool =

let hash : t -> int =
Hashtbl.hash

let unsafe_of_string : string -> t = fun s -> s

let compile : t -> string =
Printf.sprintf "Pstring.unsafe_of_string %S"
19 changes: 18 additions & 1 deletion kernel/pstring.mli
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,20 @@

(** Primitive [string] type. *)

type t = string
type t = private string

type char63 = Uint63.t

val max_length_int : int
val max_length : Uint63.t

(** [to_string s] converts the primitive string [s] into a standard string. *)
val to_string : t -> string

(** [of_string s] converts string [s] into a primitive string if its size does
not exceed [max_length_int], and returns [None] otherwise. *)
val of_string : string -> t option

(** [make i c] returns a string of size [min i max_length] containing only the
character with code [c l_and 255]. *)
val make : Uint63.t -> char63 -> t
Expand Down Expand Up @@ -49,3 +56,13 @@ val equal : t -> t -> bool

(** [hash s] gives a hash of [s], with the same value as [Hashtbl.hash s]. *)
val hash : t -> int

(** [unsafe_of_string s] converts [s] into a primitive string without checking
whether its size satisfies the length constraint. Callers of this function
must ensure that [length s <= max_length_int], which is always the case if
[s] was obtained via [to_string]. NOTE: this function is used in generated
code, via [compile]. *)
val unsafe_of_string : string -> t

(** [compile s] outputs an OCaml expression producing primitive string [s]. *)
val compile : t -> string
2 changes: 1 addition & 1 deletion kernel/typeops.mli
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ val type_of_float : env -> types
val judge_of_float : env -> Float64.t -> unsafe_judgment

val type_of_string : env -> types
val judge_of_string : env -> String.t -> unsafe_judgment
val judge_of_string : env -> Pstring.t -> unsafe_judgment

val type_of_array : env -> UVars.Instance.t -> types
val judge_of_array : env -> UVars.Instance.t -> unsafe_judgment array -> unsafe_judgment -> unsafe_judgment
Expand Down
2 changes: 1 addition & 1 deletion kernel/values.mli
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ type ('value, 'vaccu, 'vfun, 'vprod, 'vfix, 'vcofix, 'vblock) kind =
| Vblock of 'vblock
| Vint64 of int64
| Vfloat64 of float
| Vstring of string
| Vstring of Pstring.t
| Varray of 'value Parray.t
2 changes: 1 addition & 1 deletion kernel/vconv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ and conv_whd env pb k whd1 whd2 cu =
if Float64.(equal (of_float f1) (of_float f2)) then cu
else raise NotConvertible
| Vstring s1, Vstring s2 ->
if String.equal s1 s2 then cu
if Pstring.equal s1 s2 then cu
else raise NotConvertible
| Varray t1, Varray t2 ->
if t1 == t2 then cu else
Expand Down
2 changes: 1 addition & 1 deletion kernel/vmemitcodes.ml
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ type t =
| SReloc_Const_val of structured_values
| SReloc_Const_uint of Uint63.t
| SReloc_Const_float of Float64.t
| SReloc_Const_string of String.t
| SReloc_Const_string of Pstring.t
| SReloc_annot of annot_switch
| SReloc_caml_prim of caml_prim

Expand Down
4 changes: 2 additions & 2 deletions kernel/vmvalues.ml
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ let pp_struct_const = function
| Const_val _ -> Pp.str "(value)"
| Const_uint i -> Pp.str (Uint63.to_string i)
| Const_float f -> Pp.str (Float64.to_string f)
| Const_string s -> Pp.str (Printf.sprintf "%S" s)
| Const_string s -> Pp.str (Printf.sprintf "%S" (Pstring.to_string s))

(* Abstract data *)
type vprod
Expand Down Expand Up @@ -659,7 +659,7 @@ and pr_kind w =
| Vblock _b -> str "Vblock"
| Vint64 i -> i |> Format.sprintf "Vint64(%LiL)" |> str
| Vfloat64 f -> str "Vfloat64(" ++ str (Float64.(to_string (of_float f))) ++ str ")"
| Vstring s -> s |> Format.sprintf "Vstring(%S)" |> str
| Vstring s -> Pstring.to_string s |> Format.sprintf "Vstring(%S)" |> str
| Varray _ -> str "Varray"
| Vaccu (a, stk) -> str "Vaccu(" ++ pr_atom a ++ str ", " ++ pr_stack stk ++ str ")"
and pr_stack stk =
Expand Down
2 changes: 1 addition & 1 deletion plugins/extraction/json.ml
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ let rec json_expr env = function
]
| MLstring s -> json_dict [
("what", json_str "expr:string");
("string", json_str s)
("string", json_str (Pstring.to_string s))
]
| MLparray(t,def) -> json_dict [
("what", json_str "expr:array");
Expand Down
2 changes: 1 addition & 1 deletion plugins/extraction/miniml.ml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ and ml_ast =
| MLmagic of ml_ast
| MLuint of Uint63.t
| MLfloat of Float64.t
| MLstring of String.t
| MLstring of Pstring.t
| MLparray of ml_ast array * ml_ast

and ml_pattern =
Expand Down
Loading

0 comments on commit a5d9d6c

Please sign in to comment.