Skip to content

Commit

Permalink
Fix register references in Sail->SV
Browse files Browse the repository at this point in the history
  • Loading branch information
Alasdair committed Jun 13, 2024
1 parent 5d9ce75 commit 63ad7c3
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 56 deletions.
51 changes: 51 additions & 0 deletions src/sail_sv_backend/generate_primop2.ml
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,16 @@ module type S = sig
val generate_module : at:Parse_ast.l -> string -> (cval list -> ctyp -> string) option

val get_generated_library_defs : unit -> sv_def list

val hex_str : unit -> string

val dec_str : unit -> string
end

module Make
(Config : sig
val max_unknown_bitvector_width : int
val max_unknown_integer_width : int
end)
() : S = struct
let generated_library_defs = ref (StringSet.empty, [])
Expand Down Expand Up @@ -259,6 +264,52 @@ module Make
}
)

let hex_str () =
register_library_def "sail_hex_str" (fun () ->
let i = primop_name "i" in
let s = primop_name "s" in
SVD_fundef
{
function_name = SVN_string "sail_hex_str";
return_type = Some CT_string;
params = [(mk_id "i", CT_lint)];
body =
mk_statement
(SVS_block
(List.map mk_statement
[SVS_var (s, CT_string, None); svs_raw "s.hextoa(i)" ~inputs:[i] ~outputs:[s]; SVS_return (Var s)]
)
);
}
)

let dec_str () =
register_library_def "sail_dec_str" (fun () ->
let i = primop_name "i" in
let s = primop_name "s" in
SVD_fundef
{
function_name = SVN_string "sail_dec_str";
return_type = Some CT_string;
params = [(mk_id "i", CT_lint)];
body =
mk_statement
(SVS_block
(List.map mk_statement
[SVS_var (s, CT_string, None); svs_raw "s.itoa(i)" ~inputs:[i] ~outputs:[s]; SVS_return (Var s)]
)
);
}
)

let unary_module l gen =
Some
(fun args ret_ctyp ->
match (args, ret_ctyp) with
| [v], ret_ctyp -> gen v ret_ctyp
| _ -> Reporting.unreachable l __POS__ "Incorrect arity given to unary module generator"
)

let binary_module l gen =
Some
(fun args ret_ctyp ->
Expand Down
211 changes: 157 additions & 54 deletions src/sail_sv_backend/jib_sv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,16 @@ open Sv_ir

module IntSet = Util.IntSet

class footprint_visitor ctx registers reads writes need_stdout need_stderr : jib_visitor =
class footprint_visitor ctx registers references reads writes need_stdout need_stderr : jib_visitor =
object
inherit empty_jib_visitor

method! vctyp _ = SkipChildren
method! vctyp =
function
| CT_ref ctyp ->
references := CTSet.add ctyp !references;
DoChildren
| _ -> DoChildren

method! vcval =
function
Expand Down Expand Up @@ -184,9 +189,19 @@ let collect_spec_info ctx cdefs =
| CDEF_aux (CDEF_fundef (f, _, _, body), _) ->
let reads = ref IdSet.empty in
let writes = ref IdSet.empty in
let references = ref CTSet.empty in
let need_stdout = ref false in
let need_stderr = ref false in
let _ = visit_cdef (new footprint_visitor ctx registers reads writes need_stdout need_stderr) cdef in
let _ =
visit_cdef (new footprint_visitor ctx registers references reads writes need_stdout need_stderr) cdef
in
CTSet.iter
(fun ctyp ->
IdSet.iter
(fun reg -> writes := IdSet.add reg !writes)
(Option.value ~default:IdSet.empty (CTMap.find_opt ctyp register_ctyp_map))
)
!references;
Bindings.add f
{
direct_reads = !reads;
Expand Down Expand Up @@ -253,6 +268,7 @@ module Make (Config : CONFIG) = struct
Generate_primop2.Make
(struct
let max_unknown_bitvector_width = Config.max_unknown_bitvector_width
let max_unknown_integer_width = Config.max_unknown_integer_width
end)
()

Expand Down Expand Up @@ -287,6 +303,8 @@ module Make (Config : CONFIG) = struct

let pp_id id = string (pp_id_string id)

let pp_sv_name_string = function SVN_id id -> pp_id_string id | SVN_string s -> s

let pp_sv_name = function SVN_id id -> pp_id id | SVN_string s -> string s

let sv_type_id_string id = "t_" ^ pp_id_string id
Expand Down Expand Up @@ -368,13 +386,13 @@ module Make (Config : CONFIG) = struct
| _ -> Reporting.unreachable l __POS__ "string_of_bits"

let dec_str l = function
| CT_lint -> "sail_dec_str"
| CT_lint -> Primops.dec_str ()
| CT_fint sz when Config.nostrings -> Generate_primop.dec_str_fint_stub sz
| CT_fint sz -> Generate_primop.dec_str_fint sz
| _ -> Reporting.unreachable l __POS__ "dec_str"

let hex_str l = function
| CT_lint -> "sail_hex_str"
| CT_lint -> Primops.hex_str ()
| CT_fint sz when Config.nostrings -> Generate_primop.hex_str_fint_stub sz
| CT_fint sz -> Generate_primop.hex_str_fint sz
| _ -> Reporting.unreachable l __POS__ "hex_str"
Expand Down Expand Up @@ -766,7 +784,7 @@ module Make (Config : CONFIG) = struct
if packed then pp_smt v ^^ dot ^^ string "value" ^^ dot ^^ packed_ctor else pp_smt v ^^ dot ^^ pp_id ctor
| Field (_, field, v) -> pp_smt v ^^ dot ^^ pp_id field
| Ite (cond, then_exp, else_exp) ->
separate space [pp_smt_parens cond; char '?'; pp_smt_parens then_exp; char ':'; pp_smt_parens else_exp]
parens (separate space [pp_smt_parens cond; char '?'; pp_smt_parens then_exp; char ':'; pp_smt_parens else_exp])
| Empty_list -> string "{}"
| Hd (op, arg) -> begin
match tails arg with
Expand Down Expand Up @@ -877,7 +895,7 @@ module Make (Config : CONFIG) = struct
in
let updates, lexp = svir_clexp clexp in
wrap (with_updates l updates (SVS_assign (lexp, value)))
| I_funcall (creturn, _, (id, _), args) ->
| I_funcall (creturn, preserve_name, (id, _), args) ->
if ctx_is_extern id ctx then (
let name = ctx_get_extern id ctx in
if name = "sail_assert" then (
Expand Down Expand Up @@ -948,7 +966,8 @@ module Make (Config : CONFIG) = struct
else
let* args = mapM Smt.smt_cval args in
let updates, ret = svir_creturn creturn in
wrap (with_updates l updates (SVS_call (ret, SVN_id id, args)))
if preserve_name then wrap (with_updates l updates (SVS_call (ret, SVN_string (string_of_id id), args)))
else wrap (with_updates l updates (SVS_call (ret, SVN_id id, args)))
| I_block instrs ->
let* statements = fmap Util.option_these (mapM (svir_instr ctx) instrs) in
wrap (svs_block statements)
Expand Down Expand Up @@ -1106,6 +1125,19 @@ module Make (Config : CONFIG) = struct

method! vinstr (I_aux (aux, iannot) as no_change) =
match aux with
| I_copy (CL_addr (CL_id (id, CT_ref reg_ctyp)), cval) -> begin
let regs = Option.value ~default:IdSet.empty (CTMap.find_opt reg_ctyp spec_info.register_ctyp_map) in

let encoded = "sail_reg_assign_" ^ Util.zencode_string (string_of_ctyp reg_ctyp) in
let reads = List.map (fun id -> V_id (Name (id, -1), reg_ctyp)) (IdSet.elements regs) in
let writes = List.map (fun id -> CL_id (Name (id, -1), reg_ctyp)) (IdSet.elements regs) in
ChangeTo
(I_aux
( I_funcall (CR_multi writes, true, (mk_id encoded, []), V_id (id, CT_ref reg_ctyp) :: cval :: reads),
iannot
)
)
end
| I_funcall (CR_one clexp, ext, (f, []), args) -> begin
match Bindings.find_opt f spec_info.footprints with
| Some footprint ->
Expand Down Expand Up @@ -1148,6 +1180,21 @@ module Make (Config : CONFIG) = struct
iannot
)
)
else if name = "reg_deref" then (
match args with
| [cval] -> begin
match cval_ctyp cval with
| CT_ref reg_ctyp ->
let regs =
Option.value ~default:IdSet.empty (CTMap.find_opt reg_ctyp spec_info.register_ctyp_map)
in
let encoded = "sail_reg_deref_" ^ Util.zencode_string (string_of_ctyp reg_ctyp) in
let reads = List.map (fun id -> V_id (Name (id, -1), reg_ctyp)) (IdSet.elements regs) in
ChangeTo (I_aux (I_funcall (CR_one clexp, true, (mk_id encoded, []), cval :: reads), iannot))
| _ -> Reporting.unreachable (snd iannot) __POS__ "Invalid type for reg_deref argument"
end
| _ -> Reporting.unreachable (snd iannot) __POS__ "Invalid arguments for reg_deref"
)
else SkipChildren
)
else (
Expand Down Expand Up @@ -1571,7 +1618,36 @@ module Make (Config : CONFIG) = struct
^^ nest 4 (hardline ^^ separate_map hardline pp_def m.defs)
^^ hardline ^^ string "endmodule"

and pp_fundef f = string "function"
and pp_fundef f =
let ret_ty, typedef =
match f.return_type with
| Some ret_ctyp ->
let ret_ty, index_ty = sv_ctyp ret_ctyp in
begin
match index_ty with
| Some index ->
let encoded = Util.zencode_string (string_of_ctyp ret_ctyp) in
let new_ty = string ("t_" ^ pp_sv_name_string f.function_name ^ "_" ^ encoded) in
( new_ty,
separate space [string "typedef"; string ret_ty; new_ty; string index] ^^ semi ^^ twice hardline
)
| None -> (string ret_ty, empty)
end
| None -> (string "void", empty)
in
let param_docs = List.map (fun (param, ctyp) -> wrap_type ctyp (pp_id param)) f.params in
let block_terminator last = if last then semi else semi ^^ hardline in
let pp_body = function
| SVS_aux (SVS_block statements, _) ->
concat (Util.map_last (fun last -> pp_statement ~terminator:(block_terminator last)) statements)
| statement -> pp_statement ~terminator:semi statement
in
typedef
^^ separate space [string "function"; string "automatic"; ret_ty; pp_sv_name f.function_name]
^^ parens (separate (comma ^^ space) param_docs)
^^ semi
^^ nest 4 (hardline ^^ pp_body f.body)
^^ hardline ^^ string "endfunction"

and pp_def = function
| SVD_var (id, ctyp) -> wrap_type ctyp (pp_name id) ^^ semi
Expand Down Expand Up @@ -1649,12 +1725,14 @@ module Make (Config : CONFIG) = struct
)
CTMap.empty cdefs

let sv_register_references cdefs =
let rmap = collect_registers cdefs in
let sv_register_references spec_info =
let rmap = spec_info.register_ctyp_map in
let reg_ref id = "SAIL_REG_" ^ Util.zencode_upper_string (string_of_id id) in
let check reg = parens (separate space [char 'r'; string "=="; string (reg_ref reg)]) in
let reg_ref_enums =
List.map
(fun (ctyp, regs) ->
let regs = IdSet.elements regs in
separate space [string "typedef"; string "enum"; lbrace]
^^ nest 4 (hardline ^^ separate_map (comma ^^ hardline) (fun r -> string (reg_ref r)) regs)
^^ hardline ^^ rbrace ^^ space
Expand All @@ -1667,6 +1745,7 @@ module Make (Config : CONFIG) = struct
let reg_ref_functions =
List.map
(fun (ctyp, regs) ->
let regs = IdSet.elements regs in
let encoded = Util.zencode_string (string_of_ctyp ctyp) in
let sv_ty, index_ty = sv_ctyp ctyp in
let sv_ty, typedef =
Expand All @@ -1676,49 +1755,73 @@ module Make (Config : CONFIG) = struct
(new_ty, separate space [string "typedef"; string sv_ty; new_ty; string index] ^^ semi ^^ twice hardline)
| None -> (string sv_ty, empty)
in
typedef
^^ separate space [string "function"; string "automatic"; sv_ty]
^^ space
^^ string ("sail_reg_deref_" ^ encoded)
^^ parens (string ("sail_reg_" ^ encoded) ^^ space ^^ char 'r')
^^ semi
^^ nest 4
(hardline
^^ separate_map hardline
(fun reg ->
separate space
[
string "if";
parens (separate space [char 'r'; string "=="; string (reg_ref reg)]);
string "begin";
]
^^ nest 4 (hardline ^^ string "return" ^^ space ^^ pp_id reg ^^ semi)
^^ hardline ^^ string "end" ^^ semi
)
regs
)
^^ hardline ^^ string "endfunction" ^^ twice hardline
^^ separate space [string "function"; string "automatic"; string "void"]
^^ space
^^ string ("sail_reg_assign_" ^ encoded)
^^ parens (separate space [string ("sail_reg_" ^ encoded); char 'r' ^^ comma; wrap_type ctyp (char 'v')])
^^ semi
^^ nest 4
(hardline
^^ separate_map hardline
(fun reg ->
separate space
[
string "if";
parens (separate space [char 'r'; string "=="; string (reg_ref reg)]);
string "begin";
]
^^ nest 4 (hardline ^^ pp_id reg ^^ space ^^ equals ^^ space ^^ char 'v' ^^ semi)
^^ hardline ^^ string "end" ^^ semi
)
regs
)
^^ hardline ^^ string "endfunction"
let port ~input ty v = separate space [string (if input then "input" else "output"); ty; v] in
let assign_module =
let ports =
port ~input:true (string ("sail_reg_" ^ encoded)) (char 'r')
:: port ~input:true sv_ty (char 'v')
:: List.map (fun r -> port ~input:true sv_ty (pp_id (prepend_id "in_" r))) regs
@ List.map (fun r -> port ~input:false sv_ty (pp_id (prepend_id "out_" r))) regs
in
let assignment reg =
separate space
[
pp_id (prepend_id "out_" reg);
equals;
check reg;
char '?';
char 'v';
colon;
pp_id (prepend_id "in_" reg);
]
in
let comb =
nest 4 (string "begin" ^^ hardline ^^ separate_map (semi ^^ hardline) assignment regs ^^ semi)
^^ hardline ^^ string "end" ^^ semi
in
string "module" ^^ space
^^ string ("sail_reg_assign_" ^ encoded)
^^ nest 4 (lparen ^^ hardline ^^ separate (comma ^^ hardline) ports)
^^ hardline ^^ rparen ^^ semi
^^ nest 4 (hardline ^^ string "always_comb" ^^ space ^^ comb)
^^ hardline ^^ string "endmodule"
in
let deref_module =
let ports =
port ~input:true (string ("sail_reg_" ^ encoded)) (char 'r')
:: List.map (fun r -> port ~input:true sv_ty (pp_id (prepend_id "in_" r))) regs
@ [port ~input:false sv_ty (char 'v')]
in
let cases =
List.map
(fun reg ->
let assign = separate space [char 'v'; equals; pp_id (prepend_id "in_" reg)] in
(check reg, assign)
)
regs
in
let ifstmt =
match cases with
| [(_, assign)] -> assign
| _ ->
let ifs =
Util.map_last
(fun last (check, assign) ->
if last then nest 4 (hardline ^^ assign)
else string "if" ^^ space ^^ check ^^ nest 4 (hardline ^^ assign ^^ semi)
)
cases
in
separate (hardline ^^ string "else" ^^ space) ifs
in
string "module" ^^ space
^^ string ("sail_reg_deref_" ^ encoded)
^^ nest 4 (lparen ^^ hardline ^^ separate (comma ^^ hardline) ports)
^^ hardline ^^ rparen ^^ semi
^^ nest 4 (hardline ^^ string "always_comb" ^^ space ^^ ifstmt ^^ semi)
^^ hardline ^^ string "endmodule"
in
typedef ^^ assign_module ^^ twice hardline ^^ deref_module
)
(CTMap.bindings rmap)
|> separate (twice hardline)
Expand Down
Loading

0 comments on commit 63ad7c3

Please sign in to comment.