Skip to content

Commit

Permalink
Track memory management for Ndarray
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Jul 29, 2024
1 parent a8ed541 commit f2377cc
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 16 deletions.
12 changes: 8 additions & 4 deletions arrayjit/lib/cc_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,9 @@ let%track_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_
Hashtbl.fold lowered.traced_store ~init:ctx_arrays ~f:(fun ~key:tn ~data:_ ctx_arrays ->
match Map.find ctx_arrays tn with
| None ->
let debug = "CC compile-time ctx array for " ^ Tn.debug_name tn in
let data =
Ndarray.create_array tn.Tn.prec ~dims:(Lazy.force tn.dims)
Ndarray.create_array ~debug tn.Tn.prec ~dims:(Lazy.force tn.dims)
@@ Constant_fill { values = [| 0. |]; strict = false }
in
Map.add_exn ctx_arrays ~key:tn ~data
Expand Down Expand Up @@ -120,13 +121,14 @@ let%track_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_
let log_fname = base_name ^ ".log" in
let libname = base_name ^ ".so" in
(try Stdlib.Sys.remove log_fname with _ -> ());
(try Stdlib.Sys.remove libname with _ -> ());
let cmdline =
Printf.sprintf "%s %s -O%d -o %s --shared >> %s 2>&1" (compiler_command ()) pp_file.f_name
(optimization_level ()) libname log_fname
in
let _rc = Stdlib.Sys.command cmdline in
(* FIXME: don't busy wait *)
while not @@ Stdlib.Sys.file_exists log_fname do
while not @@ (Stdlib.Sys.file_exists libname && Stdlib.Sys.file_exists log_fname) do
()
done;
let result = Dl.dlopen ~filename:libname ~flags:[ RTLD_NOW; RTLD_DEEPBIND ] in
Expand All @@ -141,8 +143,9 @@ let%track_sexp compile_batch ~names ~opt_ctx_arrays bindings
Hashtbl.fold lowered.traced_store ~init:ctx_arrays ~f:(fun ~key:tn ~data:_ ctx_arrays ->
match Map.find ctx_arrays tn with
| None ->
let debug = "CC compile-time ctx array for " ^ Tn.debug_name tn in
let data =
Ndarray.create_array tn.Tn.prec ~dims:(Lazy.force tn.dims)
Ndarray.create_array ~debug tn.Tn.prec ~dims:(Lazy.force tn.dims)
@@ Constant_fill { values = [| 0. |]; strict = false }
in
Map.add_exn ctx_arrays ~key:tn ~data
Expand Down Expand Up @@ -218,7 +221,8 @@ let%track_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
let f = function
| Some arr -> arr
| None ->
Ndarray.create_array tn.Tn.prec ~dims:(Lazy.force tn.dims)
let debug = "CC link-time ctx array for " ^ Tn.debug_name tn in
Ndarray.create_array ~debug tn.Tn.prec ~dims:(Lazy.force tn.dims)
@@ Constant_fill { values = [| 0. |]; strict = false }
in
Map.update ctx_arrays tn ~f
Expand Down
6 changes: 4 additions & 2 deletions arrayjit/lib/gcc_backend.gccjit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,9 @@ let prepare_node ~debug_log_zero_out ~get_ident ctx nodes traced_store ctx_nodes
| From_context, Ctx_arrays ctx_arrays -> (
match Map.find !ctx_arrays tn with
| None ->
let debug = "GCCJIT compile-time ctx array for " ^ Tn.debug_name tn in
let data =
Ndarray.create_array tn.Tn.prec ~dims
Ndarray.create_array ~debug tn.Tn.prec ~dims
@@ Constant_fill { values = [| 0. |]; strict = false }
in
ctx_arrays := Map.add_exn !ctx_arrays ~key:tn ~data;
Expand Down Expand Up @@ -805,7 +806,8 @@ let%track_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
let f = function
| Some arr -> arr
| None ->
Ndarray.create_array tn.Tn.prec ~dims:(Lazy.force tn.dims)
let debug = "GCCJIT link-time ctx array for " ^ Tn.debug_name tn in
Ndarray.create_array ~debug tn.Tn.prec ~dims:(Lazy.force tn.dims)
@@ Constant_fill { values = [| 0. |]; strict = false }
in
Map.update ctx_arrays tn ~f
Expand Down
31 changes: 23 additions & 8 deletions arrayjit/lib/ndarray.ml
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
open Base
(** N-dimensional arrays: a precision-handling wrapper for [Bigarray.Genarray] and its utilities. *)

module Debug_runtime = Utils.Debug_runtime

[%%global_debug_log_level Nothing]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

module A = Bigarray.Genarray

(** {2 *** Handling of precisions ***} *)
Expand Down Expand Up @@ -145,13 +150,6 @@ let create_bigarray (type ocaml elt_t) (prec : (ocaml, elt_t) Ops.precision) ~di
Unix.close fd;
ba

let create_array prec ~dims init_op =
let f prec = as_array prec @@ create_bigarray prec ~dims init_op in
Ops.map_prec { f } prec

let empty_array prec =
create_array prec ~dims:[||] (Constant_fill { values = [| 0.0 |]; strict = false })

(** {2 *** Accessing ***} *)

type 'r map_as_bigarray = { f : 'ocaml 'elt_t. ('ocaml, 'elt_t) bigarray -> 'r }
Expand Down Expand Up @@ -363,13 +361,30 @@ let retrieve_flat_values arr =
iter 0;
Array.of_list_rev !result

(** {2 *** Printing ***} *)
(** {2 *** Creating ***} *)

let c_ptr_to_string nd =
let prec = get_prec nd in
let f arr = Ops.ptr_to_string (Ctypes.bigarray_start Ctypes_static.Genarray arr) prec in
map { f } nd

let create_array ~debug:_debug prec ~dims init_op =
let f prec = as_array prec @@ create_bigarray prec ~dims init_op in
let result = Ops.map_prec { f } prec in
if Utils.settings.with_debug_level > 2 then
[%debug_sexp
[%log_entry
"create_array";
[%log _debug, c_ptr_to_string result]]];
let%debug_sexp debug_finalizer _result = [%log "Deleting", _debug, c_ptr_to_string _result] in
if Utils.settings.with_debug_level > 2 then Stdlib.Gc.finalise debug_finalizer result;
result

let empty_array prec =
create_array prec ~dims:[||] (Constant_fill { values = [| 0.0 |]; strict = false })

(** {2 *** Printing ***} *)

(** Dimensions to string, ["x"]-separated, e.g. 1x2x3 for batch dims 1, input dims 3, output dims 2.
Outputs ["-"] for empty dimensions. *)
let int_dims_to_string ?(with_axis_numbers = false) dims =
Expand Down
4 changes: 3 additions & 1 deletion arrayjit/lib/tnode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,11 @@ end)
let registry = Registry.create 16

let create prec ~id ~label ~dims init_op =
let debug = "Host array for " ^ get_debug_name ~id ~label () in
let rec array =
lazy
(if is_hosted_force tn 30 then Some (Nd.create_array prec ~dims:(Lazy.force dims) init_op)
(if is_hosted_force tn 30 then
Some (Nd.create_array ~debug prec ~dims:(Lazy.force dims) init_op)
else None)
and tn =
{
Expand Down
5 changes: 4 additions & 1 deletion lib/tensor.ml
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,10 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?
Stdlib.Format.pp_set_geometry Stdlib.Format.str_formatter ~max_indent:!max_sublabel_length
~margin:(!max_sublabel_length * 2);
let dims = Array.concat_map [| batch_ds; output_ds; input_ds |] ~f:Array.of_list in
let ndarr = Nd.create_array Arrayjit.Ops.double ~dims (Constant_fill { values; strict }) in
let debug = "Temporary array for pretty-printing" in
let ndarr =
Nd.create_array ~debug Arrayjit.Ops.double ~dims (Constant_fill { values; strict })
in
let ( ! ) = List.length in
Nd.pp_array_inline ~num_batch_axes:!batch_ds ~num_output_axes:!output_ds
~num_input_axes:!input_ds Stdlib.Format.str_formatter ndarr;
Expand Down

0 comments on commit f2377cc

Please sign in to comment.