Skip to content

Commit

Permalink
fix: cancellation
Browse files Browse the repository at this point in the history
* Add cancellation mechanism to jsonrpc
* Add cancellable requests to lsp
* Add tests for both

Signed-off-by: Rudi Grinberg <[email protected]>

ps-id: 9A5AEFA0-DAA0-49F9-9B9B-C2FB6AB20BF5
  • Loading branch information
tatchi authored and rgrinberg committed Jun 13, 2022
1 parent 3bfb3c3 commit 398c31e
Show file tree
Hide file tree
Showing 11 changed files with 333 additions and 107 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## Features

- Fix cancellation mechanism for all requests (#707)

- Allow cancellation of formatting requests (#707)

- Add `--fallback-read-dot-merlin` to the LSP Server (#705). If `ocamllsp` is
Expand Down
36 changes: 32 additions & 4 deletions jsonrpc-fiber/src/jsonrpc_fiber.ml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ struct
; on_request : ('state, Request.t) context -> (Reply.t * 'state) Fiber.t
; on_notification :
('state, Notification.t) context -> (Notify.t * 'state) Fiber.t
; pending : (Response.t, [ `Stopped ]) result Fiber.Ivar.t Id.Table.t
; pending :
(Response.t, [ `Stopped | `Cancelled ]) result Fiber.Ivar.t Id.Table.t
; stopped : unit Fiber.Ivar.t
; name : string
; mutable running : bool
Expand All @@ -81,6 +82,10 @@ struct

and ('a, 'message) context = 'a t * 'message

type cancel = unit Fiber.t

let fire cancel = cancel

module Context = struct
type nonrec ('a, 'id) t = ('a, 'id) context

Expand Down Expand Up @@ -190,10 +195,13 @@ struct
| None ->
log "dropped";
Fiber.return ()
| Some ivar ->
| Some ivar -> (
log "acknowledged";
Id.Table.remove t.pending r.id;
Fiber.Ivar.fill ivar (Ok r)
let* resp = Fiber.Ivar.peek ivar in
match resp with
| Some _ -> Fiber.return ()
| None -> Fiber.Ivar.fill ivar (Ok r))
and on_request (r : Request.t) =
let* result =
let sent = ref false in
Expand Down Expand Up @@ -276,6 +284,7 @@ struct
let+ res = Fiber.Ivar.read ivar in
match res with
| Ok s -> s
| Error `Cancelled -> assert false
| Error `Stopped -> raise (Stopped req)

let request t (req : Request.t) =
Expand All @@ -286,9 +295,28 @@ struct
register_request_ivar t req.id ivar;
read_request_ivar req ivar)

let request_with_cancel t (req : Request.t) =
let ivar = Fiber.Ivar.create () in
let cancel = Fiber.Ivar.fill ivar (Error `Cancelled) in
let resp =
Fiber.of_thunk (fun () ->
check_running t;
let* () =
let+ () = Chan.send t.chan [ Request req ] in
register_request_ivar t req.id ivar
in
let+ res = Fiber.Ivar.read ivar in
match res with
| Ok s -> `Ok s
| Error `Cancelled -> `Cancelled
| Error `Stopped -> raise (Stopped req))
in
(cancel, resp)

module Batch = struct
type response =
Jsonrpc.Request.t * (Jsonrpc.Response.t, [ `Stopped ]) result Fiber.Ivar.t
Jsonrpc.Request.t
* (Jsonrpc.Response.t, [ `Stopped | `Cancelled ]) result Fiber.Ivar.t

type t = [ `Notification of Notification.t | `Request of response ] list ref

Expand Down
9 changes: 9 additions & 0 deletions jsonrpc-fiber/src/jsonrpc_fiber.mli
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ end) : sig

val request : _ t -> Jsonrpc.Request.t -> Jsonrpc.Response.t Fiber.t

type cancel

val fire : cancel -> unit Fiber.t

val request_with_cancel :
_ t
-> Jsonrpc.Request.t
-> cancel * [ `Ok of Jsonrpc.Response.t | `Cancelled ] Fiber.t

module Batch : sig
type t

Expand Down
96 changes: 95 additions & 1 deletion jsonrpc-fiber/test/jsonrpc_fiber_tests.ml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
open Stdune
open! Jsonrpc
open Jsonrpc
open Jsonrpc_fiber
open Fiber.O
open Fiber.Stream
Expand Down Expand Up @@ -263,3 +263,97 @@ let%expect_test "test from jsonrpc_test.ml" =
<opaque>
{ "id": 10, "jsonrpc": "2.0", "result": 1 }
{ "id": "testing", "jsonrpc": "2.0", "result": 2 } |}]

let%expect_test "cancellation" =
let print packet =
print_endline
(Yojson.Safe.pretty_to_string ~std:false
(Jsonrpc.Packet.yojson_of_t packet))
in
let server_req_ack = Fiber.Ivar.create () in
let client_req_ack = Fiber.Ivar.create () in
let server chan =
let on_request c =
let request = Context.message c in
let state = Context.state c in
print_endline "server: received request";
print (Request request);
let* () = Fiber.Ivar.fill server_req_ack () in
let response =
Reply.later (fun send ->
print_endline
"server: waiting for client ack before sending response";
let* () = Fiber.Ivar.read client_req_ack in
print_endline "server: got client ack, sending response";
send (Jsonrpc.Response.ok request.id (`String "Ok")))
in
Fiber.return (response, state)
in
Jrpc.create ~name:"server" ~on_request chan ()
in
let client chan = Jrpc.create ~name:"client" chan () in
let responses = ref [] in
let run () =
let pool = Fiber.Pool.create () in
let client_in, _ = pipe () in
let server_in, client_out = pipe () in
let out = of_ref responses in
let client = client (client_in, client_out) in
let server = server (server_in, out) in
let request =
Jsonrpc.Request.create ~id:(`String "initial") ~method_:"init" ()
in
let cancel, req = Jrpc.request_with_cancel client request in
let* () =
Fiber.Pool.task pool ~f:(fun () ->
print_endline
"client: waiting for server ack before cancelling request";
let* () = Fiber.Ivar.read server_req_ack in
print_endline "client: got server ack, cancelling request";
let* () = Jrpc.fire cancel in
Fiber.Ivar.fill client_req_ack ())
in
let initial_request () =
print_endline "client: sending request";
let* resp = req in
(match resp with
| `Cancelled -> print_endline "request has been cancelled"
| `Ok resp ->
print_endline "request response:";
print (Response resp));
Fiber.return ()
in
let all =
Fiber.all_concurrently
[ Fiber.Pool.run pool
; Jrpc.run client
; initial_request ()
; Jrpc.run server
]
in
Fiber.fork_and_join_unit
(fun () ->
Fiber.fork_and_join_unit
(fun () -> Jrpc.stopped client)
(fun () -> Jrpc.stopped server))
(fun () -> all)
in
Fiber_test.test Dyn.opaque run;
(* Ensure that server still responds even if the request was cancelled.
Required by the lsp spec *)
List.rev !responses
|> List.iter ~f:(fun packet ->
let json = Jsonrpc.Packet.yojson_of_t packet in
print_json json);
[%expect
{|
client: sending request
client: waiting for server ack before cancelling request
server: received request
{ "id": "initial", "method": "init", "jsonrpc": "2.0" }
server: waiting for client ack before sending response
client: got server ack, cancelling request
request has been cancelled
server: got client ack, sending response
[FAIL] unexpected Never raised
{ "id": "initial", "jsonrpc": "2.0", "result": "Ok" } |}]
109 changes: 60 additions & 49 deletions lsp-fiber/src/rpc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -12,44 +12,9 @@ module Reply = struct
let now r = Now r

let later f = Later f

let to_jsonrpc t id to_json : Jsonrpc_fiber.Reply.t =
let f x = Jsonrpc.Response.ok id (to_json x) in
match t with
| Now r -> Jsonrpc_fiber.Reply.now (f r)
| Later k -> Jsonrpc_fiber.Reply.later (fun send -> k (fun r -> send (f r)))
end

module Cancel = struct
type state =
| Pending of { mutable callbacks : (unit -> unit Fiber.t) list }
| Finished

type t = state ref

let var = Fiber.Var.create ()

let register f =
let+ cancel = Fiber.Var.get var in
match cancel with
| None -> ()
| Some cancel -> (
match !cancel with
| Finished -> ()
| Pending p -> p.callbacks <- f :: p.callbacks)

let create () = ref (Pending { callbacks = [] })

let destroy t = t := Finished

let cancel t =
Fiber.of_thunk (fun () ->
match !t with
| Finished -> Fiber.return ()
| Pending { callbacks } ->
t := Finished;
Fiber.parallel_iter callbacks ~f:(fun f -> f ()))
end
let cancel_token = Fiber.Var.create ()

module State = struct
type t =
Expand Down Expand Up @@ -97,7 +62,7 @@ module type S = sig

val notification : _ t -> out_notification -> unit Fiber.t

val on_cancel : (unit -> unit Fiber.t) -> unit Fiber.t
val cancel_token : unit -> Fiber.Cancel.t option Fiber.t

module Batch : sig
type t
Expand Down Expand Up @@ -168,7 +133,7 @@ struct
; (* Filled when the server is initialied *)
initialized : Initialize.t Fiber.Ivar.t
; mutable req_id : int
; pending : Cancel.t Table.t
; pending : Fiber.Cancel.t Table.t
; detached : Fiber.Pool.t
}

Expand Down Expand Up @@ -219,20 +184,38 @@ struct
Fiber.return
(Jsonrpc_fiber.Reply.now (Jsonrpc.Response.error req.id error), state)
| Ok (In_request.E r) ->
let cancel = Cancel.create () in
Table.replace t.pending req.id cancel;
let cancel = Fiber.Cancel.create () in
let remove = lazy (Table.remove t.pending req.id) in
let+ response, state =
Fiber.finalize
Fiber.with_error_handler
~on_error:
(Stdune.Exn_with_backtrace.map_and_reraise ~f:(fun exn ->
Lazy.force remove;
exn))
(fun () ->
Fiber.Var.set Cancel.var cancel (fun () ->
Fiber.Var.set cancel_token cancel (fun () ->
Table.replace t.pending req.id cancel;
h_on_request.on_request t r))
~finally:(fun () ->
Cancel.destroy cancel;
Table.remove t.pending req.id;
Fiber.return ())
in
let to_response x =
Jsonrpc.Response.ok req.id (In_request.yojson_of_result r x)
in
let reply =
Reply.to_jsonrpc response req.id (In_request.yojson_of_result r)
match response with
| Reply.Now r ->
Lazy.force remove;
Jsonrpc_fiber.Reply.now (to_response r)
| Reply.Later k ->
let f send =
Fiber.finalize
(fun () ->
Fiber.Var.set cancel_token cancel (fun () ->
k (fun r -> send (to_response r))))
~finally:(fun () ->
Lazy.force remove;
Fiber.return ())
in
Jsonrpc_fiber.Reply.later f
in
(reply, state)
in
Expand Down Expand Up @@ -286,6 +269,29 @@ struct
in
receive_response req resp)

let request_with_cancel (type r) (t : _ t) cancel ~on_cancel
(req : r Out_request.t) : [ `Ok of r | `Cancelled ] Fiber.t =
let* () = Fiber.return () in
let jsonrpc_req = create_request t req in
let+ resp, cancel_status =
Fiber.Cancel.with_handler cancel
~on_cancel:(fun () -> on_cancel jsonrpc_req.id)
(fun () ->
let _, req_f =
Session.request_with_cancel (Fdecl.get t.session) jsonrpc_req
in
let+ resp = req_f in
match resp with
| `Cancelled -> `Cancelled
| `Ok resp -> `Ok (receive_response req resp))
in
match cancel_status with
| Cancelled () -> `Cancelled
| Not_cancelled -> (
match resp with
| `Ok resp -> `Ok resp
| `Cancelled -> assert false)

let notification (t : _ t) (n : Out_notification.t) : unit Fiber.t =
let jsonrpc_request = Out_notification.to_jsonrpc n in
Session.notification (Fdecl.get t.session) jsonrpc_request
Expand Down Expand Up @@ -341,11 +347,12 @@ struct
let+ () =
match Table.find_opt t.pending id with
| None -> Fiber.return ()
| Some id -> Fiber.Pool.task t.detached ~f:(fun () -> Cancel.cancel id)
| Some token ->
Fiber.Pool.task t.detached ~f:(fun () -> Fiber.Cancel.fire token)
in
(Jsonrpc_fiber.Notify.Continue, state t)

let on_cancel = Cancel.register
let cancel_token () = Fiber.Var.get cancel_token
end

module Client = struct
Expand All @@ -366,6 +373,10 @@ module Client = struct
let h_on_notification = h_on_notification handler in
make ~name:"client" handler.h_on_request h_on_notification io

let request_with_cancel t cancel r =
request_with_cancel t cancel r ~on_cancel:(fun id ->
notification t (Client_notification.CancelRequest id))

let start (t : _ t) (p : InitializeParams.t) =
Fiber.of_thunk (fun () ->
assert (t.state = Waiting_for_init);
Expand Down
9 changes: 8 additions & 1 deletion lsp-fiber/src/rpc.mli
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ module type S = sig

val notification : _ t -> out_notification -> unit Fiber.t

val on_cancel : (unit -> unit Fiber.t) -> unit Fiber.t
(** only available inside requests *)
val cancel_token : unit -> Fiber.Cancel.t option Fiber.t

module Batch : sig
type t
Expand Down Expand Up @@ -81,6 +82,12 @@ module Client : sig
and type 'a in_request = 'a Server_request.t
and type in_notification = Server_notification.t

val request_with_cancel :
_ t
-> Fiber.Cancel.t
-> 'resp out_request
-> [ `Ok of 'resp | `Cancelled ] Fiber.t

val initialized : _ t -> InitializeResult.t Fiber.t

val start : _ t -> InitializeParams.t -> unit Fiber.t
Expand Down
Loading

0 comments on commit 398c31e

Please sign in to comment.