Skip to content

Commit

Permalink
Finalize the transition to using local debug runtimes
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Aug 16, 2024
1 parent e245c50 commit e4b82ab
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 225 deletions.
198 changes: 71 additions & 127 deletions arrayjit/lib/backends.ml

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions arrayjit/lib/cc_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ let alloc_buffer ?old_buffer ~size_in_bytes () =
| Some (_old_ptr, _old_size) -> assert false
| None -> assert false

let to_buffer ?rt:_ tn ~dst ~src =
let to_buffer tn ~dst ~src =
let src = Map.find_exn src.arrays tn in
Ndarray.map2 { f2 = Ndarray.A.blit } src dst

let host_to_buffer ?rt:_ src ~dst = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
let buffer_to_host ?rt:_ dst ~src = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
let host_to_buffer src ~dst = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
let buffer_to_host dst ~src = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
let unsafe_cleanup () = Stdlib.Gc.compact ()

let is_initialized, initialize =
Expand Down
73 changes: 26 additions & 47 deletions arrayjit/lib/cuda_backend.cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -192,39 +192,30 @@ let unsafe_cleanup () =
done;
Core.Weak.fill !devices 0 len None

let%diagn_sexp from_host ?(rt : (module Minidebug_runtime.Debug_runtime) option) (ctx : context) tn
=
let%diagn_l_sexp from_host (ctx : context) tn =
match (tn, Map.find ctx.global_arrays tn) with
| { Tn.array = (lazy (Some hosted)); _ }, Some dst ->
set_ctx ctx.ctx;
(if Utils.settings.with_debug_level > 0 then
let module Debug_runtime =
(val Option.value_or_thunk rt ~default:(fun () -> (module Debug_runtime)))
in
[%log "copying", Tn.debug_name tn, "to", (dst : ctx_array), "from host"]);
if Utils.settings.with_debug_level > 0 then
[%log "copying", Tn.debug_name tn, "to", (dst : ctx_array), "from host"];
let f src = Cudajit.memcpy_H_to_D_async ~dst ~src ctx.device.stream in
Ndarray.map { f } hosted;
true
| _ -> false

let%track_sexp to_host ?(rt : (module Minidebug_runtime.Debug_runtime) option) (ctx : context)
(tn : Tn.t) =
let%track_l_sexp to_host (ctx : context) (tn : Tn.t) =
match (tn, Map.find ctx.global_arrays tn) with
| { Tn.array = (lazy (Some hosted)); _ }, Some src ->
set_ctx ctx.ctx;
(if Utils.settings.with_debug_level > 0 then
let module Debug_runtime =
(val Option.value_or_thunk rt ~default:(fun () ->
(module Debug_runtime : Minidebug_runtime.Debug_runtime)))
in
[%log "copying", Tn.debug_name tn, "at", (src : ctx_array), "to host"]);
if Utils.settings.with_debug_level > 0 then
[%log "copying", Tn.debug_name tn, "at", (src : ctx_array), "to host"];
let f dst = Cudajit.memcpy_D_to_H_async ~dst ~src ctx.device.stream in
Ndarray.map { f } hosted;
true
| _ -> false

let%track_sexp rec device_to_device ?(rt : (module Minidebug_runtime.Debug_runtime) option)
(tn : Tn.t) ~into_merge_buffer ~(dst : context) ~(src : context) =
let%track_l_sexp rec device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : context)
~(src : context) =
let memcpy ~d_arr ~s_arr =
if phys_equal dst.device.physical src.device.physical then
Cudajit.memcpy_D_to_D_async ~size_in_bytes:(Tn.size_in_bytes tn) ~dst:d_arr ~src:s_arr
Expand All @@ -243,46 +234,34 @@ let%track_sexp rec device_to_device ?(rt : (module Minidebug_runtime.Debug_runti
| Some d_arr ->
set_ctx dst.ctx;
memcpy ~d_arr ~s_arr;
(if Utils.settings.with_debug_level > 0 then
let module Debug_runtime =
(val Option.value_or_thunk rt ~default:(fun () ->
(module Debug_runtime : Minidebug_runtime.Debug_runtime)))
in
[%log
"copied",
Tn.debug_name tn,
"from",
src.label,
"at",
(s_arr : ctx_array),
"to",
(d_arr : ctx_array)]);
if Utils.settings.with_debug_level > 0 then
[%log
"copied",
Tn.debug_name tn,
"from",
src.label,
"at",
(s_arr : ctx_array),
"to",
(d_arr : ctx_array)];
true)
| Streaming ->
if phys_equal dst.device.physical src.device.physical then (
dst.device.merge_buffer <- Some (s_arr, tn);
(if Utils.settings.with_debug_level > 0 then
let module Debug_runtime =
(val Option.value_or_thunk rt ~default:(fun () ->
(module Debug_runtime : Minidebug_runtime.Debug_runtime)))
in
[%log "using merge buffer for", Tn.debug_name tn, "from", src.label]);
if Utils.settings.with_debug_level > 0 then
[%log "using merge buffer for", Tn.debug_name tn, "from", src.label];
true)
else
(* TODO: support proper streaming, but it might be difficult. *)
device_to_device ?rt tn ~into_merge_buffer:Copy ~dst ~src
device_to_device tn ~into_merge_buffer:Copy ~dst ~src
| Copy ->
set_ctx dst.ctx;
let size_in_bytes = Tn.size_in_bytes tn in
opt_alloc_merge_buffer ~size_in_bytes dst.device.physical;
memcpy ~d_arr:dst.device.physical.copy_merge_buffer ~s_arr;
dst.device.merge_buffer <- Some (dst.device.physical.copy_merge_buffer, tn);
(if Utils.settings.with_debug_level > 0 then
let module Debug_runtime =
(val Option.value_or_thunk rt ~default:(fun () ->
(module Debug_runtime : Minidebug_runtime.Debug_runtime)))
in
[%log "copied into merge buffer", Tn.debug_name tn, "from", src.label]);
if Utils.settings.with_debug_level > 0 then
[%log "copied into merge buffer", Tn.debug_name tn, "from", src.label];
true)

type code = {
Expand Down Expand Up @@ -522,7 +501,7 @@ let%track_sexp link_batch prior_context (code_batch : code_batch) =
in
(context, lowered_bindings, procs)

let to_buffer ?rt:_ _tn ~dst:_ ~src:_ = failwith "CUDA low-level: NOT IMPLEMENTED YET"
let host_to_buffer ?rt:_ _tn ~dst:_ = failwith "CUDA low-level: NOT IMPLEMENTED YET"
let buffer_to_host ?rt:_ _tn ~src:_ = failwith "CUDA low-level: NOT IMPLEMENTED YET"
let to_buffer _tn ~dst:_ ~src:_ = failwith "CUDA low-level: NOT IMPLEMENTED YET"
let host_to_buffer _tn ~dst:_ = failwith "CUDA low-level: NOT IMPLEMENTED YET"
let buffer_to_host _tn ~src:_ = failwith "CUDA low-level: NOT IMPLEMENTED YET"
let get_buffer _tn _context = failwith "CUDA low-level: NOT IMPLEMENTED YET"
20 changes: 8 additions & 12 deletions arrayjit/lib/cuda_backend.missing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ let ctx_arrays Unimplemented_ctx = Map.empty (module Tnode)

let link (Unimplemented_ctx : context) (code : code) =
let lowered_bindings = List.map ~f:(fun s -> (s, ref 0)) @@ Indexing.bound_symbols code in
let task =
Tnode.{ description = "CUDA missing: install cudajit"; work = (fun _debug_runtime () -> ()) }
in
let task = Tnode.{ description = "CUDA missing: install cudajit"; work = (fun () -> ()) } in
((Unimplemented_ctx : context), lowered_bindings, task)

let link_batch (Unimplemented_ctx : context) (code_batch : code_batch) =
Expand All @@ -31,16 +29,14 @@ let link_batch (Unimplemented_ctx : context) (code_batch : code_batch) =
in
let task =
Array.map code_batch ~f:(fun _ ->
Some
Tnode.
{ description = "CUDA missing: install cudajit"; work = (fun _debug_runtime () -> ()) })
Some Tnode.{ description = "CUDA missing: install cudajit"; work = (fun () -> ()) })
in
((Unimplemented_ctx : context), lowered_bindings, task)

let unsafe_cleanup () = ()
let from_host ?rt:_ _context _tn = false
let to_host ?rt:_ _context _tn = false
let device_to_device ?rt:_ _tn ~into_merge_buffer:_ ~dst:_ ~src:_ = false
let from_host _context _tn = false
let to_host _context _tn = false
let device_to_device _tn ~into_merge_buffer:_ ~dst:_ ~src:_ = false

type device = Unimplemented_dev [@@deriving sexp_of]
type physical_device = Unimplemented_phys_dev [@@deriving sexp_of]
Expand All @@ -58,7 +54,7 @@ let get_ctx_device Unimplemented_ctx = Unimplemented_dev
let get_name Unimplemented_dev : string = failwith "CUDA missing: install cudajit"
let to_ordinal _device = 0
let to_subordinal _device = 0
let to_buffer ?rt:_ _tn ~dst:_ ~src:_ = failwith "CUDA missing: install cudajit"
let host_to_buffer ?rt:_ _tn ~dst:_ = failwith "CUDA missing: install cudajit"
let buffer_to_host ?rt:_ _tn ~src:_ = failwith "CUDA missing: install cudajit"
let to_buffer _tn ~dst:_ ~src:_ = failwith "CUDA missing: install cudajit"
let host_to_buffer _tn ~dst:_ = failwith "CUDA missing: install cudajit"
let buffer_to_host _tn ~src:_ = failwith "CUDA missing: install cudajit"
let get_buffer _tn _context = failwith "CUDA missing: install cudajit"
23 changes: 6 additions & 17 deletions arrayjit/lib/cuda_backend.mli
Original file line number Diff line number Diff line change
Expand Up @@ -28,32 +28,21 @@ val link_batch :

val unsafe_cleanup : unit -> unit

val from_host : ?rt:(module Minidebug_runtime.Debug_runtime) -> context -> Tnode.t -> bool
val from_host : context -> Tnode.t -> bool
(** If the array is both hosted and in-context, schedules a copy from host to context and returns true, otherwise returns false. *)

val to_host : ?rt:(module Minidebug_runtime.Debug_runtime) -> context -> Tnode.t -> bool
val to_host : context -> Tnode.t -> bool
(** If the array is both hosted and in-context, schedules a copy from context to host and returns true, otherwise returns false. *)

val device_to_device :
?rt:(module Minidebug_runtime.Debug_runtime) ->
Tnode.t ->
into_merge_buffer:merge_buffer_use ->
dst:context ->
src:context ->
bool
Tnode.t -> into_merge_buffer:merge_buffer_use -> dst:context -> src:context -> bool
(** If the array is in both contexts, copies from [src] to [dst]. *)

type buffer_ptr [@@deriving sexp_of]

val to_buffer :
?rt:(module Minidebug_runtime.Debug_runtime) -> Tnode.t -> dst:buffer_ptr -> src:context -> unit

val host_to_buffer :
?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> dst:buffer_ptr -> unit

val buffer_to_host :
?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> src:buffer_ptr -> unit

val to_buffer : Tnode.t -> dst:buffer_ptr -> src:context -> unit
val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
val get_buffer : Tnode.t -> context -> buffer_ptr option

type physical_device
Expand Down
6 changes: 3 additions & 3 deletions arrayjit/lib/gcc_backend.gccjit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ type context = {

let ctx_arrays context = context.arrays

let to_buffer ?rt:_ tn ~dst ~src =
let to_buffer tn ~dst ~src =
let src = Map.find_exn src.arrays tn in
Ndarray.map2 { f2 = Ndarray.A.blit } src dst

let host_to_buffer ?rt:_ src ~dst = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
let buffer_to_host ?rt:_ dst ~src = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
let host_to_buffer src ~dst = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
let buffer_to_host dst ~src = Ndarray.map2 { f2 = Ndarray.A.blit } src dst

let unsafe_cleanup () =
let open Gccjit in
Expand Down
12 changes: 6 additions & 6 deletions arrayjit/lib/gcc_backend.missing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ let expected_merge_node Unimplemented_proc =

let is_in_context _node = failwith "gcc backend missing: install the optional dependency gccjit"

let to_buffer ?rt:_ _tn ~dst:_ ~src:_ =
let to_buffer _tn ~dst:_ ~src:_ =
failwith "gcc backend missing: install the optional dependency gccjit"

let host_to_buffer ?rt:_ _src ~dst:_ =
let host_to_buffer _src ~dst:_ =
failwith "gcc backend missing: install the optional dependency gccjit"

let buffer_to_host ?rt:_ _dst ~src:_ =
let buffer_to_host _dst ~src:_ =
failwith "gcc backend missing: install the optional dependency gccjit"

let alloc_buffer ?old_buffer:_ ~size_in_bytes:_ () =
Expand All @@ -35,13 +35,13 @@ let compile_batch ~names:_ ~opt_ctx_arrays:_ _bindings _codes =
let link_compiled ~merge_buffer:_ Unimplemented_ctx Unimplemented_proc =
failwith "gcc backend missing: install the optional dependency gccjit"

let from_host ?rt:_ Unimplemented_ctx _tn =
let from_host Unimplemented_ctx _tn =
failwith "gcc backend missing: install the optional dependency gccjit"

let to_host ?rt:_ Unimplemented_ctx _tn =
let to_host Unimplemented_ctx _tn =
failwith "gcc backend missing: install the optional dependency gccjit"

let device_to_device ?rt:_ _tn ~into_merge_buffer:_ ~dst:_ ~src:_ =
let device_to_device _tn ~into_merge_buffer:_ ~dst:_ ~src:_ =
failwith "gcc backend missing: install the optional dependency gccjit"

let physical_merge_buffers = false
Expand Down
17 changes: 7 additions & 10 deletions arrayjit/lib/writing_a_backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ Currently, OCANNL integrates new backends via code in [Backends](backends.ml), s
```ocaml
type lowered_bindings = (static_symbol, int ref) List.Assoc.t (* in indexing.ml *)
type task = Work of ((module Debug_runtime) -> unit -> unit) (* in tnode.ml *)
type task =
| Task : { context_lifetime : 'a; description : string; work : unit -> unit; } -> task (* in tnode.ml *)
type 'context routine = {
context : 'context;
Expand Down Expand Up @@ -253,33 +254,29 @@ module type No_device_backend = sig
...
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> unit -> buffer_ptr
...
val to_buffer :
?rt:(module Minidebug_runtime.Debug_runtime) -> Tnode.t -> dst:buffer_ptr -> src:context -> unit
val to_buffer : Tnode.t -> dst:buffer_ptr -> src:context -> unit
val host_to_buffer :
?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> dst:buffer_ptr -> unit
val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
val buffer_to_host :
?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> src:buffer_ptr -> unit
val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
val get_buffer : Tnode.t -> context -> buffer_ptr option
end
module type Backend = sig
include No_device_backend
...
val from_host : ?rt:(module Minidebug_runtime.Debug_runtime) -> context -> Tnode.t -> bool
val from_host : context -> Tnode.t -> bool
(** If the array is both hosted and in-context, schedules a copy from host to context and returns
true, otherwise returns false. NOTE: when run for a device, it's the caller's responsibility
to synchronize the device before the host's data is overwritten. *)
val to_host : ?rt:(module Minidebug_runtime.Debug_runtime) -> context -> Tnode.t -> bool
val to_host : context -> Tnode.t -> bool
(** If the array is both hosted and in-context, schedules a copy from context to host and returns
true, otherwise returns false. NOTE: when run for a device, it's the caller's responsibility
to synchronize the device before the host's data is read. *)
val device_to_device :
?rt:(module Minidebug_runtime.Debug_runtime) ->
Tnode.t ->
into_merge_buffer:merge_buffer_use ->
dst:context ->
Expand Down

0 comments on commit e4b82ab

Please sign in to comment.