Skip to content

Commit

Permalink
Migrate to using local debug runtimes instead of runtime passing
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Aug 9, 2024
1 parent 930998d commit 8cac8f7
Show file tree
Hide file tree
Showing 16 changed files with 90 additions and 68 deletions.
2 changes: 2 additions & 0 deletions arrayjit/lib/assignments.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ module Lazy = Utils.Lazy
module Tn = Tnode
module Debug_runtime = Utils.Debug_runtime

let _get_local_debug_runtime = Utils._get_local_debug_runtime

[%%global_debug_log_level Nothing]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

Expand Down
5 changes: 4 additions & 1 deletion arrayjit/lib/backend_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ open Base
module Lazy = Utils.Lazy
module Debug_runtime = Utils.Debug_runtime

let _get_local_debug_runtime = Utils._get_local_debug_runtime

[%%global_debug_log_level Nothing]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

Expand Down Expand Up @@ -305,7 +307,8 @@ struct
Sexp.Atom
(if B.is_in_context node then "From_context"
else if Hash_set.mem is_global tn then "Constant_from_host"
else if Tn.is_virtual_force tn 3331 then "Virtual" else "Local_only")
else if Tn.is_virtual_force tn 3331 then "Virtual"
else "Local_only")
in
if not @@ Utils.sexp_mem ~elem:backend_info tn.backend_info then
tn.backend_info <- Utils.sexp_append ~elem:backend_info tn.backend_info;
Expand Down
76 changes: 33 additions & 43 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ open Base
open Backend_utils.Types
module Debug_runtime = Utils.Debug_runtime

let _get_local_debug_runtime = Utils._get_local_debug_runtime

[%%global_debug_log_level Nothing]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

Expand Down Expand Up @@ -186,8 +188,8 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
mutable dev_pos : task_list;
mutable dev_previous_pos : task_list;
mut : (Mut.t[@sexp.opaque]);
host_wait_for_idle : (Condition.t[@sexp.opaque]);
dev_wait_for_work : (Condition.t[@sexp.opaque]);
host_wait_for_idle : (Stdlib.Condition.t[@sexp.opaque]);
dev_wait_for_work : (Stdlib.Condition.t[@sexp.opaque]);
mutable is_idle : bool;
mutable host_is_waiting : bool; (** The host is waiting for this specific device. *)
}
Expand Down Expand Up @@ -233,7 +235,7 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
while (not d.is_idle) && d.keep_spinning do
(* Stdlib.Printf.printf "DEBUG: await waiting is_idle=%b host_is_waiting=%b
is_empty=%b\n%!" d.is_idle d.host_is_waiting (is_dev_queue_empty d); *)
Condition.wait d.host_wait_for_idle d.mut
Stdlib.Condition.wait d.host_wait_for_idle d.mut
done;
d.host_is_waiting <- false);
(* Stdlib.Printf.printf "DEBUG: await unlocking is_idle=%b host_is_waiting=%b is_empty=%b\n%!"
Expand All @@ -257,7 +259,7 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
Mut.lock d.mut;
(* Stdlib.Printf.printf "DEBUG: schedule task broadcasting is_idle=%b host_is_waiting=%b
is_empty=%b\n%!" d.is_idle d.host_is_waiting (is_dev_queue_empty d); *)
Condition.broadcast d.dev_wait_for_work;
Stdlib.Condition.broadcast d.dev_wait_for_work;
Mut.unlock d.mut)

let global_run_no = ref 0
Expand All @@ -269,11 +271,7 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
{
hd =
Tnode.Task
{
context_lifetime = ();
description = "root of task queue";
work = (fun _rt () -> ());
};
{ context_lifetime = (); description = "root of task queue"; work = (fun () -> ()) };
tl = Empty;
}
in
Expand All @@ -286,8 +284,8 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
dev_previous_pos = init_pos;
mut = Mut.create ();
is_idle = false;
host_wait_for_idle = Condition.create ();
dev_wait_for_work = Condition.create ();
host_wait_for_idle = Stdlib.Condition.create ();
dev_wait_for_work = Stdlib.Condition.create ();
host_is_waiting = false;
}
in
Expand All @@ -313,11 +311,11 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
(* Stdlib.Printf.printf "DEBUG: worker empty broadcasting is_idle=%b
host_is_waiting=%b is_empty=%b\n%!" state.is_idle state.host_is_waiting
(is_dev_queue_empty state); *)
if state.host_is_waiting then Condition.broadcast state.host_wait_for_idle;
if state.host_is_waiting then Stdlib.Condition.broadcast state.host_wait_for_idle;
(* Stdlib.Printf.printf "DEBUG: worker empty waiting is_idle=%b host_is_waiting=%b
is_empty=%b\n%!" state.is_idle state.host_is_waiting (is_dev_queue_empty
state); *)
Condition.wait state.dev_wait_for_work state.mut
Stdlib.Condition.wait state.dev_wait_for_work state.mut
done;
state.is_idle <- false);
(* Stdlib.Printf.printf "DEBUG: worker empty unlocking is_idle=%b host_is_waiting=%b
Expand All @@ -327,7 +325,7 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
task_list)]; *)
state.dev_pos <- Utils.tl_exn state.dev_previous_pos
| Cons { hd; tl } ->
Tnode.run _debug_runtime hd;
Tnode.run hd;
(* [%log "WORK WHILE LOOP: AFTER WORK"]; *)
state.dev_previous_pos <- state.dev_pos;
state.dev_pos <- tl
Expand All @@ -349,7 +347,7 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
}

let%diagn_sexp make_work device (Tnode.Task { description; _ } as task) =
let%diagn_rt_sexp work () = schedule_task device task in
let%diagn_l_sexp work () = schedule_task device task in
Tnode.Task
{
context_lifetime = task;
Expand Down Expand Up @@ -410,10 +408,9 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
@@ Option.map (Backend.get_buffer tn context.ctx) ~f:(fun c_arr ->
match tn.Tnode.array with
| (lazy (Some h_arr)) ->
let work rt () =
Backend.host_to_buffer ~rt h_arr ~dst:c_arr;
let%diagn_l_sexp work () =
Backend.host_to_buffer h_arr ~dst:c_arr;
if Utils.settings.with_debug_level > 0 then
let module Debug_runtime = (val rt) in
[%diagn_sexp
[%log_entry
"from_host " ^ Tnode.debug_name tn;
Expand Down Expand Up @@ -450,10 +447,9 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
@@ Option.map (Backend.get_buffer tn context.ctx) ~f:(fun c_arr ->
match tn.Tnode.array with
| (lazy (Some h_arr)) ->
let work rt () =
Backend.buffer_to_host ~rt h_arr ~src:c_arr;
let%diagn_l_sexp work () =
Backend.buffer_to_host h_arr ~src:c_arr;
if Utils.settings.with_debug_level > 0 then
let module Debug_runtime = (val rt) in
[%diagn_sexp
[%log_entry
"to_host " ^ Tnode.debug_name tn;
Expand Down Expand Up @@ -501,13 +497,13 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
let work =
(* TODO: log the operation if [Utils.settings.with_log_level > 0]. *)
match into_merge_buffer with
| No -> fun rt () -> Backend.to_buffer ~rt tn ~dst ~src:src.ctx
| No -> fun () -> Backend.to_buffer tn ~dst ~src:src.ctx
| Streaming ->
fun _rt () ->
fun () ->
dev.merge_buffer :=
Option.map ~f:(fun ptr -> (ptr, tn)) @@ Backend.get_buffer tn src.ctx
| Copy ->
fun rt () ->
fun () ->
let size_in_bytes = Tnode.size_in_bytes tn in
let allocated_capacity =
Option.value ~default:0 @@ Option.map dev.allocated_buffer ~f:snd
Expand All @@ -519,7 +515,7 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
size_in_bytes );
let merge_ptr = fst @@ Option.value_exn dev.allocated_buffer in
dev.merge_buffer := Some (merge_ptr, tn);
Backend.to_buffer ~rt tn ~dst:merge_ptr ~src:src.ctx
Backend.to_buffer tn ~dst:merge_ptr ~src:src.ctx
in
let description =
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ Int.to_string dev.ordinal ^ " src "
Expand All @@ -542,7 +538,7 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
let wait_for_finish device =
await device;
device.state.keep_spinning <- false;
Condition.broadcast device.state.dev_wait_for_work
Stdlib.Condition.broadcast device.state.dev_wait_for_work
in
Array.iter devices ~f:(Option.iter ~f:wait_for_finish);
let cleanup ordinal device =
Expand Down Expand Up @@ -643,11 +639,7 @@ module Pipes_multicore_backend (Backend : No_device_backend) : Backend = struct
{
hd =
Tnode.Task
{
context_lifetime = ();
description = "root of task queue";
work = (fun _rt () -> ());
};
{ context_lifetime = (); description = "root of task queue"; work = (fun () -> ()) };
tl = Empty;
}
in
Expand Down Expand Up @@ -681,7 +673,7 @@ module Pipes_multicore_backend (Backend : No_device_backend) : Backend = struct
(* [%log "WORK WHILE LOOP: EMPTY AFTER WAIT -- dev pos:", (state.dev_pos : task_list)]; *)
state.dev_pos <- Utils.tl_exn state.dev_previous_pos
| Cons { hd; tl } ->
Tnode.run _debug_runtime hd;
Tnode.run hd;
(* [%log "WORK WHILE LOOP: AFTER WORK"]; *)
state.dev_previous_pos <- state.dev_pos;
state.dev_pos <- tl
Expand All @@ -705,7 +697,7 @@ module Pipes_multicore_backend (Backend : No_device_backend) : Backend = struct
}

let%diagn_sexp make_work device (Tnode.Task { context_lifetime; description; _ } as task) =
let%diagn_rt_sexp work () = schedule_task device task in
let%diagn_l_sexp work () = schedule_task device task in
Tnode.Task
{
context_lifetime;
Expand Down Expand Up @@ -766,10 +758,9 @@ module Pipes_multicore_backend (Backend : No_device_backend) : Backend = struct
@@ Option.map (Backend.get_buffer tn context.ctx) ~f:(fun c_arr ->
match tn.Tnode.array with
| (lazy (Some h_arr)) ->
let work rt () =
Backend.host_to_buffer ~rt h_arr ~dst:c_arr;
let%diagn_l_sexp work () =
Backend.host_to_buffer h_arr ~dst:c_arr;
if Utils.settings.with_debug_level > 0 then
let module Debug_runtime = (val rt) in
[%diagn_sexp
[%log_entry
"from_host " ^ Tnode.debug_name tn;
Expand Down Expand Up @@ -806,10 +797,9 @@ module Pipes_multicore_backend (Backend : No_device_backend) : Backend = struct
@@ Option.map (Backend.get_buffer tn context.ctx) ~f:(fun c_arr ->
match tn.Tnode.array with
| (lazy (Some h_arr)) ->
let work rt () =
Backend.buffer_to_host ~rt h_arr ~src:c_arr;
let%diagn_l_sexp work () =
Backend.buffer_to_host h_arr ~src:c_arr;
if Utils.settings.with_debug_level > 0 then
let module Debug_runtime = (val rt) in
[%diagn_sexp
[%log_entry
"to_host " ^ Tnode.debug_name tn;
Expand Down Expand Up @@ -856,13 +846,13 @@ module Pipes_multicore_backend (Backend : No_device_backend) : Backend = struct
let schedule dst_ptr =
let work =
match into_merge_buffer with
| No -> fun rt () -> Backend.to_buffer ~rt tn ~dst:dst_ptr ~src:src.ctx
| No -> fun () -> Backend.to_buffer tn ~dst:dst_ptr ~src:src.ctx
| Streaming ->
fun _rt () ->
fun () ->
dev.merge_buffer :=
Option.map ~f:(fun ptr -> (ptr, tn)) @@ Backend.get_buffer tn src.ctx
| Copy ->
fun rt () ->
fun () ->
let size_in_bytes = Tnode.size_in_bytes tn in
let allocated_capacity =
Option.value ~default:0 @@ Option.map dev.allocated_buffer ~f:snd
Expand All @@ -874,7 +864,7 @@ module Pipes_multicore_backend (Backend : No_device_backend) : Backend = struct
size_in_bytes );
let merge_ptr = fst @@ Option.value_exn dev.allocated_buffer in
dev.merge_buffer := Some (merge_ptr, tn);
Backend.to_buffer ~rt tn ~dst:merge_ptr ~src:src.ctx
Backend.to_buffer tn ~dst:merge_ptr ~src:src.ctx
in
schedule_task dev
(Tnode.Task
Expand Down
6 changes: 4 additions & 2 deletions arrayjit/lib/cc_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ open Base
module Lazy = Utils.Lazy
module Debug_runtime = Utils.Debug_runtime

let _get_local_debug_runtime = Utils._get_local_debug_runtime

[%%global_debug_log_level Nothing]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

Expand Down Expand Up @@ -272,12 +274,12 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
let params = List.rev_map code.params ~f:(fun (_, p) -> p) in
link code.bindings params Ctypes.(void @-> returning void)]
in
let%diagn_rt_sexp work () : unit =
let%diagn_l_sexp work () : unit =
[%log_result name];
Backend_utils.check_merge_buffer ~merge_buffer ~code_node:code.lowered.merge_node;
Indexing.apply run_variadic ();
if Utils.settings.debug_log_from_routines then (
Utils.log_trace_tree _debug_runtime (Stdio.In_channel.read_lines log_file_name);
Utils.log_trace_tree (module Debug_runtime) (Stdio.In_channel.read_lines log_file_name);
Stdlib.Sys.remove log_file_name)
in
( context,
Expand Down
6 changes: 4 additions & 2 deletions arrayjit/lib/cuda_backend.cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ module Lazy = Utils.Lazy
module Debug_runtime = Utils.Debug_runtime
open Backend_utils.Types

let _get_local_debug_runtime = Utils._get_local_debug_runtime

[%%global_debug_log_level Nothing]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

Expand Down Expand Up @@ -410,7 +412,7 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~glo
{ prior_context with parent = Some prior_context; run_module = Some run_module; global_arrays }
in
Stdlib.Gc.finalise finalize context;
let%diagn_rt_sexp work () : unit =
let%diagn_l_sexp work () : unit =
let log_id = get_global_run_id () in
let log_id_prefix = Int.to_string log_id ^ ": " in
if Utils.settings.with_debug_level > 0 then
Expand Down Expand Up @@ -454,7 +456,7 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~glo
Utils.add_log_processor ~prefix:log_id_prefix @@ fun output ->
[%log_entry
context.label;
Utils.log_trace_tree _debug_runtime output]);
Utils.log_trace_tree (module Debug_runtime) output]);
(* if Utils.settings.debug_log_from_routines then Cu.ctx_set_limit CU_LIMIT_PRINTF_FIFO_SIZE
4096; *)
Cu.launch_kernel func ~grid_dim_x:1 ~block_dim_x:1 ~shared_mem_bytes:0 context.device.stream
Expand Down
6 changes: 4 additions & 2 deletions arrayjit/lib/gcc_backend.gccjit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ open Base
module Lazy = Utils.Lazy
module Debug_runtime = Utils.Debug_runtime

let _get_local_debug_runtime = Utils._get_local_debug_runtime

[%%global_debug_log_level Nothing]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

Expand Down Expand Up @@ -849,12 +851,12 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
indices. *)
link code.bindings (List.rev code.params) Ctypes.(void @-> returning void)]
in
let%diagn_rt_sexp work () : unit =
let%diagn_l_sexp work () : unit =
[%log_result name];
Backend_utils.check_merge_buffer ~merge_buffer ~code_node:code.expected_merge_node;
Indexing.apply run_variadic ();
if Utils.settings.debug_log_from_routines then (
Utils.log_trace_tree _debug_runtime (Stdio.In_channel.read_lines log_file_name);
Utils.log_trace_tree (module Debug_runtime) (Stdio.In_channel.read_lines log_file_name);
Stdlib.Sys.remove log_file_name)
in
( context,
Expand Down
2 changes: 2 additions & 0 deletions arrayjit/lib/low_level.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ module Nd = Ndarray
module Tn = Tnode
module Debug_runtime = Utils.Debug_runtime

let _get_local_debug_runtime = Utils._get_local_debug_runtime

[%%global_debug_log_level Nothing]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

Expand Down
2 changes: 2 additions & 0 deletions arrayjit/lib/ndarray.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ open Base

module Debug_runtime = Utils.Debug_runtime

let _get_local_debug_runtime = Utils._get_local_debug_runtime

[%%global_debug_log_level Nothing]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

Expand Down
9 changes: 5 additions & 4 deletions arrayjit/lib/tnode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,25 @@ module Lazy = Utils.Lazy
module Nd = Ndarray
module Debug_runtime = Utils.Debug_runtime

let _get_local_debug_runtime = Utils._get_local_debug_runtime

[%%global_debug_log_level Nothing]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

type task =
| Task : {
context_lifetime : ('a[@sexp.opaque]);
description : string;
work : (module Minidebug_runtime.Debug_runtime) -> unit -> unit;
work : unit -> unit;
}
-> task
[@@deriving sexp_of]

let run debug_runtime (Task task) =
let module Debug_runtime = (val debug_runtime : Minidebug_runtime.Debug_runtime) in
let run (Task task) =
[%diagn_sexp
[%log_entry
task.description;
task.work debug_runtime ()]]
task.work ()]]

type memory_type =
| Constant (** The tensor node does not change after initialization. *)
Expand Down
Loading

0 comments on commit 8cac8f7

Please sign in to comment.