diff --git a/beluga.opam b/beluga.opam index 30f57b4c7..ecb519652 100644 --- a/beluga.opam +++ b/beluga.opam @@ -25,7 +25,7 @@ depends: [ "omd" {>= "1.3.2"} "uri" {>= "4.2.0"} "ocamlformat" {= "0.25.1" & with-test} - "yojson" {>= "2.0.2" & with-test} + "yojson" {>= "2.0.2"} "ounit2" {>= "2.2.6" & with-test} "bisect_ppx" {>= "2.8.1" & with-test} "odoc" {>= "2.2.0" & with-doc} diff --git a/dune-project b/dune-project index cf3873a0e..c17f4b155 100644 --- a/dune-project +++ b/dune-project @@ -49,9 +49,7 @@ (= 0.25.1) :with-test)) (yojson - (and - (>= 2.0.2) - :with-test)) + (>= 2.0.2)) (ounit2 (and (>= 2.2.6) diff --git a/src/beluga/callGraph.ml b/src/beluga/callGraph.ml index ba36f0565..4579e5165 100644 --- a/src/beluga/callGraph.ml +++ b/src/beluga/callGraph.ml @@ -1,4 +1,4 @@ -[@@@warning "+A-4-44"] +[@@@warning "+A-4-32-44"] open! Support open Beluga @@ -27,10 +27,13 @@ module CallsRecord : sig val add_program_call : t -> Id.cid_prog -> unit - val has_program_call : t -> Id.cid_prog -> bool [@@warning "-32"] + val has_program_call : t -> Id.cid_prog -> bool (** [iter f r] iteratively applies [f] over the program calls added to [r]. *) val iter : (Id.cid_prog -> unit) -> t -> unit + + (** [to_set r] is the set of called program IDs recorded in [r]. *) + val to_set : t -> Id.Prog.Set.t end = struct type t = { mutable called_programs : Id.Prog.Set.t } @@ -43,6 +46,8 @@ end = struct Id.Prog.Set.mem prog calls.called_programs let iter f calls = Id.Prog.Set.iter f calls.called_programs + + let to_set calls = calls.called_programs end type calls_record = CallsRecord.t @@ -206,6 +211,10 @@ module CallGraphState : sig graph. *) val add_program_calls_record : t -> Id.cid_prog -> calls_record -> unit + (** [get_immediate_dependencies state cid] is the set of direct + dependencies of theorem [cid] in [state]. *) + val get_immediate_dependencies : t -> Id.cid_prog -> Id.Prog.Set.t + (** [compute_program_call_dependencies state cid] is the set of transitive dependencies of theorem [cid] in [state]. That is, it is the set of nodes reachable from [cid] in the call graph. @@ -219,7 +228,6 @@ module CallGraphState : sig whenever memoized results have to be invalidated and recomputed from scratch. *) val clear_memoized_call_dependencies : t -> unit - [@@warning "-32"] (** [set_program_display_name state cid n] sets [n] as the name to use to refer to [cid] in [state]. This is only used for pretty-printing the @@ -253,10 +261,19 @@ end = struct Id.Prog.Hashtbl.add state.program_calls_records cid calls; clear_memoized_call_dependencies state + let get_immediate_dependencies state cid = + match Id.Prog.Hashtbl.find_opt state.program_calls_records cid with + | Option.None -> Error.raise (Unknown_program cid) + | Option.Some calls_record -> CallsRecord.to_set calls_record + let compute_program_call_dependencies state cid = let to_visit = Queue.create () in let visited = Stdlib.ref Id.Prog.Set.empty in - Queue.push cid to_visit; + (* Add direct dependencies to the [to_visit] queue *) + (match Id.Prog.Hashtbl.find_opt state.program_calls_records cid with + | Option.None -> Error.raise (Unknown_program cid) + | Option.Some calls_record -> + CallsRecord.iter (fun x -> Queue.add x to_visit) calls_record); while Bool.not (Queue.is_empty to_visit) do let current_cid = Queue.pop to_visit in if Bool.not (Id.Prog.Set.mem current_cid !visited) then ( @@ -415,6 +432,95 @@ and pp_call_graph_sgn_declaration : | Sgn.Val _ -> () +(** {2 Dependency Data to JSON} *) + +let json_of_location : Location.t -> Yojson.Safe.t = + fun location -> + if Location.is_ghost location then `Null + else + `Assoc + [ ("filename", `String (Location.filename location)) + ; ("start_line", `Int (Location.start_line location)) + ; ("start_column", `Int (Location.start_column location)) + ; ("stop_line", `Int (Location.stop_line location)) + ; ("stop_column", `Int (Location.stop_column location)) + ] + +let rec json_of_call_graph_sgn : state -> Sgn.sgn -> Yojson.Safe.t = + fun state sgn -> + `List + (List.flatten + (List1.to_list (List1.map (json_of_call_graph_sgn_file state) sgn))) + +and json_of_call_graph_sgn_file : state -> Sgn.sgn_file -> Yojson.Safe.t list + = + fun state { Sgn.entries; _ } -> + let programs = + entries + |> List.map (dependencies_to_json_call_graph_sgn_entry state) + |> List.flatten + in + programs + +and dependencies_to_json_call_graph_sgn_entry : + state -> Sgn.entry -> Yojson.Safe.t list = + fun state -> function + | Sgn.Declaration { declaration; _ } -> + json_of_call_graph_sgn_declaration state declaration + | Sgn.Pragma _ + | Sgn.Comment _ -> + [] + +and json_of_cid_prog : Id.cid_prog -> Yojson.Safe.t = + fun cid -> `Int (Id.Prog.to_int cid) + +and json_of_call_graph_sgn_declaration : + state -> Sgn.decl -> Yojson.Safe.t list = + fun state -> function + | Sgn.Theorem { cid; _ } -> + let display_name = + cid |> CallGraphState.lookup_program_display_name state + in + let immediate_dependencies = + cid + |> CallGraphState.get_immediate_dependencies state + |> Id.Prog.Set.to_seq |> Seq.map json_of_cid_prog |> List.of_seq + in + let dependencies = + cid + |> CallGraphState.compute_program_call_dependencies state + |> Id.Prog.Set.to_seq |> Seq.map json_of_cid_prog |> List.of_seq + in + [ `Assoc + [ ("id", json_of_cid_prog cid) + ; ( "qualified_identifier" + , `String (Qualified_identifier.show display_name) ) + ; ( "location" + , json_of_location (Qualified_identifier.location display_name) + ) + ; ("immediate_dependencies", `List immediate_dependencies) + ; ("transitive_dependencies", `List dependencies) + ] + ] + | Sgn.Recursive_declarations { declarations; _ } -> + List1.to_list declarations + |> List.map (json_of_call_graph_sgn_declaration state) + |> List.flatten + | Sgn.Module { entries; _ } -> + entries + |> List.map (dependencies_to_json_call_graph_sgn_entry state) + |> List.flatten + | Sgn.Typ _ + | Sgn.Const _ + | Sgn.CompTyp _ + | Sgn.CompCotyp _ + | Sgn.CompConst _ + | Sgn.CompDest _ + | Sgn.CompTypAbbrev _ + | Sgn.Schema _ + | Sgn.Val _ -> + [] + (** {2 Driver} *) (** CLI usage: [dune exec beluga_call_graph ./path-to-signature.cfg] *) @@ -425,7 +531,9 @@ let main () = | [ file ] -> let _, sgn = Load.load_fresh file in let call_graph = construct_call_graph_state sgn in - pp_call_graph_sgn call_graph Format.std_formatter sgn + Format.fprintf Format.std_formatter "%a@." + (Yojson.Safe.pretty_print ~std:true) + (json_of_call_graph_sgn call_graph sgn) | [] -> Format.fprintf Format.err_formatter "Provide the file path to the Beluga signature.@."; diff --git a/src/beluga/dune b/src/beluga/dune index dc22237ff..810757285 100644 --- a/src/beluga/dune +++ b/src/beluga/dune @@ -9,5 +9,5 @@ (name callGraph) (public_name beluga_call_graph) (package beluga) - (libraries support beluga beluga_syntax) + (libraries support beluga beluga_syntax yojson) (modules callGraph))