Skip to content

Commit

Permalink
Cleanup expected_merge_node(s) after no longer verifying in `device…
Browse files Browse the repository at this point in the history
…_to_device`
  • Loading branch information
lukstafi committed Oct 11, 2024
1 parent 2858d24 commit ef76c9b
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 92 deletions.
3 changes: 0 additions & 3 deletions arrayjit/lib/backend_types.ml
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ module type No_device_backend = sig
(** Finalizes (just) the context. *)

val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> unit -> buffer_ptr
val expected_merge_node : code -> Tnode.t option
val expected_merge_nodes : code_batch -> Tnode.t option array

val compile : ?shared:bool -> ?name:string -> Indexing.unit_bindings -> Assignments.comp -> code
(** If [~shared:true] (default [false]), the backend should prefer to do more compile work in a
Expand Down Expand Up @@ -202,7 +200,6 @@ module type Lowered_no_device_backend = sig
val buffer_ptr : ctx_array -> buffer_ptr
val ctx_arrays : context -> ctx_arrays
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> unit -> buffer_ptr
val expected_merge_node : procedure -> Tnode.t option

val is_in_context : Low_level.traced_array -> bool
(** If true, the node is required to be in the contexts linked with code that uses it.
Expand Down
122 changes: 40 additions & 82 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,6 @@ struct
type code = Backend.code [@@deriving sexp_of]
type code_batch = Backend.code_batch [@@deriving sexp_of]

let expected_merge_node (code : code) = Backend.expected_merge_node code
let expected_merge_nodes (codes : code_batch) = Backend.expected_merge_nodes codes
let is_dev_queue_empty state = Queue.size state.queue = 0
let is_idle stream = is_dev_queue_empty stream.state && stream.state.is_ready
let name = "multicore " ^ Backend.name
Expand Down Expand Up @@ -166,49 +164,42 @@ struct
allocated_buffer = None;
}

type context = { stream : stream; ctx : Backend.context; expected_merge_node : Tnode.t option }
[@@deriving sexp_of]

type context = { stream : stream; ctx : Backend.context } [@@deriving sexp_of]
type nonrec routine = context routine [@@deriving sexp_of]

let init stream =
{
stream;
ctx = Backend.init ~label:(name ^ " " ^ Int.to_string stream.ordinal);
expected_merge_node = None;
}
{ stream; ctx = Backend.init ~label:(name ^ " " ^ Int.to_string stream.ordinal) }

let initialize = Backend.initialize
let is_initialized = Backend.is_initialized

let finalize { stream; ctx; expected_merge_node = _ } =
let finalize { stream; ctx } =
await stream;
Backend.finalize ctx

let compile = Backend.compile
let compile_batch = Backend.compile_batch
let get_stream_name s = "stream " ^ Int.to_string s.ordinal

let link { ctx; stream; expected_merge_node = _ } code =
let link { ctx; stream } code =
let task = Backend.link ~merge_buffer:stream.merge_buffer ctx code in
{
task with
context =
{ ctx = task.context; stream; expected_merge_node = Backend.expected_merge_node code };
context = { ctx = task.context; stream };
schedule = Task.enschedule ~schedule_task ~get_stream_name stream task.schedule;
}

let link_batch { ctx; stream; expected_merge_node } code_batch =
let link_batch { ctx; stream } code_batch =
let ctx, routines = Backend.link_batch ~merge_buffer:stream.merge_buffer ctx code_batch in
let merge_nodes = Backend.expected_merge_nodes code_batch in
( { ctx; stream; expected_merge_node },
Array.mapi routines ~f:(fun i ->
Option.map ~f:(fun task ->
{
task with
context = { ctx = task.context; stream; expected_merge_node = merge_nodes.(i) };
schedule = Task.enschedule ~schedule_task ~get_stream_name stream task.schedule;
})) )
( { ctx; stream },
Array.map routines
~f:
(Option.map ~f:(fun task ->
{
task with
context = { ctx = task.context; stream };
schedule = Task.enschedule ~schedule_task ~get_stream_name stream task.schedule;
})) )

let from_host (context : context) (tn : Tnode.t) =
Option.value ~default:false
Expand Down Expand Up @@ -381,44 +372,31 @@ module Sync_backend (Backend : Backend_types.No_device_backend) : Backend_types.
type code = Backend.code [@@deriving sexp_of]
type code_batch = Backend.code_batch [@@deriving sexp_of]

let expected_merge_node (code : code) = Backend.expected_merge_node code
let expected_merge_nodes (codes : code_batch) = Backend.expected_merge_nodes codes
let all_work _stream = ()
let is_idle _stream = true
let name = "sync " ^ Backend.name
let await _stream = ()
(* let global_run_no = ref 0 *)

type context = { stream : stream; ctx : Backend.context; expected_merge_node : Tnode.t option }
[@@deriving sexp_of]

type context = { stream : stream; ctx : Backend.context } [@@deriving sexp_of]
type nonrec routine = context routine [@@deriving sexp_of]

let init stream = { stream; ctx = Backend.init ~label:name; expected_merge_node = None }
let init stream = { stream; ctx = Backend.init ~label:name }
let initialize = Backend.initialize
let is_initialized = Backend.is_initialized
let finalize { stream = _; ctx; expected_merge_node = _ } = Backend.finalize ctx
let finalize { stream = _; ctx } = Backend.finalize ctx
let compile = Backend.compile
let compile_batch = Backend.compile_batch

let link { ctx; stream; expected_merge_node = _ } code =
let link { ctx; stream } code =
let task = Backend.link ~merge_buffer:stream.merge_buffer ctx code in
{
task with
context =
{ ctx = task.context; stream; expected_merge_node = Backend.expected_merge_node code };
}
{ task with context = { ctx = task.context; stream } }

let link_batch { ctx; stream; expected_merge_node } code_batch =
let link_batch { ctx; stream } code_batch =
let ctx, routines = Backend.link_batch ~merge_buffer:stream.merge_buffer ctx code_batch in
let merge_nodes = Backend.expected_merge_nodes code_batch in
( { ctx; stream; expected_merge_node },
Array.mapi routines ~f:(fun i ->
Option.map ~f:(fun task ->
{
task with
context = { ctx = task.context; stream; expected_merge_node = merge_nodes.(i) };
})) )
( { ctx; stream },
Array.map routines
~f:(Option.map ~f:(fun task -> { task with context = { ctx = task.context; stream } })) )

let get_name stream = Int.to_string stream.subordinal

Expand Down Expand Up @@ -609,15 +587,14 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back
type nonrec routine = context routine [@@deriving sexp_of]

let expected_merge_node : code -> _ = function
| Postponed { lowered = Low_level.{ merge_node; _ }; _ } -> merge_node
| Compiled { proc; _ } -> Backend.expected_merge_node proc
| Postponed { lowered = Low_level.{ merge_node; _ }; _ }
| Compiled { lowered = Low_level.{ merge_node; _ }; _ } ->
merge_node

let expected_merge_nodes : code_batch -> _ = function
| Postponed { lowereds; _ } ->
| Postponed { lowereds; _ } | Compiled { lowereds; _ } ->
Array.map lowereds ~f:(fun lowered ->
Option.(join @@ map lowered ~f:(fun optim -> optim.merge_node)))
| Compiled { procs = _, procs; _ } ->
Array.map ~f:(function Some proc -> Backend.expected_merge_node proc | _ -> None) procs

let get_traced_store : code -> _ = function
| Postponed { lowered = Low_level.{ traced_store; _ }; _ }
Expand Down Expand Up @@ -721,10 +698,6 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back
(context, Some { context; schedule; bindings; name })
| None -> (context, None))

let to_buffer tn ~dst ~src = Backend.to_buffer tn ~dst ~src
let host_to_buffer = Backend.host_to_buffer
let buffer_to_host = Backend.buffer_to_host

let get_buffer tn context =
Map.find (Backend.ctx_arrays context) tn |> Option.map ~f:Backend.buffer_ptr
end
Expand Down Expand Up @@ -760,14 +733,10 @@ module Lowered_backend (Device : Backend_types.Lowered_backend) : Backend_types.
}
[@@deriving sexp_of]

let expected_merge_node code = code.expected_merge_node
let expected_merge_nodes code_batch = code_batch.expected_merge_nodes

type nonrec context = { ctx : context; expected_merge_node : Tnode.t option } [@@deriving sexp_of]
type nonrec routine = context routine [@@deriving sexp_of]

let work_for context = work_for context.ctx
let will_wait_for context = will_wait_for context.ctx
let work_for context = work_for context
let will_wait_for context = will_wait_for context

let compile ?shared:_ ?name bindings comp : code =
let name, lowered = lower_assignments ?name bindings comp.Assignments.asgns in
Expand Down Expand Up @@ -801,43 +770,32 @@ module Lowered_backend (Device : Backend_types.Lowered_backend) : Backend_types.
}

let link context (code : code) =
verify_prior_context ~ctx_arrays ~is_in_context ~prior_context:context.ctx
verify_prior_context ~ctx_arrays ~is_in_context ~prior_context:context
~from_prior_context:code.from_prior_context [| code.traced_store |];
let ctx, bindings, schedule = link context.ctx code.code in
let context, bindings, schedule = link context code.code in
let schedule =
Task.prepend schedule ~work:(fun () ->
check_merge_buffer
~scheduled_node:(scheduled_merge_node @@ get_ctx_stream context.ctx)
~code_node:(expected_merge_node code))
~scheduled_node:(scheduled_merge_node @@ get_ctx_stream context)
~code_node:code.expected_merge_node)
in
{ context = { ctx; expected_merge_node = code.expected_merge_node }; schedule; bindings; name }
{ context; schedule; bindings; name }

let link_batch context code_batch =
verify_prior_context ~ctx_arrays ~is_in_context ~prior_context:context.ctx
verify_prior_context ~ctx_arrays ~is_in_context ~prior_context:context
~from_prior_context:code_batch.from_prior_context code_batch.traced_stores;
let ctx, bindings, schedules = link_batch context.ctx code_batch.code_batch in
( { ctx; expected_merge_node = context.expected_merge_node },
let context, bindings, schedules = link_batch context code_batch.code_batch in
( context,
Array.mapi schedules ~f:(fun i ->
Option.map ~f:(fun schedule ->
let expected_merge_node = code_batch.expected_merge_nodes.(i) in
let schedule =
Task.prepend schedule ~work:(fun () ->
check_merge_buffer
~scheduled_node:(scheduled_merge_node @@ get_ctx_stream context.ctx)
~scheduled_node:(scheduled_merge_node @@ get_ctx_stream context)
~code_node:expected_merge_node)
in
{ context = { ctx; expected_merge_node }; schedule; bindings; name })) )

let init stream = { ctx = init stream; expected_merge_node = None }
let get_ctx_stream context = get_ctx_stream context.ctx
let finalize context = finalize context.ctx
let to_buffer tn ~dst ~src = to_buffer tn ~dst ~src:src.ctx
let get_buffer tn context = get_buffer tn context.ctx
let from_host context tn = from_host context.ctx tn
let to_host context tn = to_host context.ctx tn

let device_to_device tn ~into_merge_buffer ~dst ~src =
device_to_device tn ~into_merge_buffer ~dst:dst.ctx ~src:src.ctx
{ context; schedule; bindings; name })) )
end

module Cuda_backend : Backend_types.Backend = Lowered_backend ((
Expand Down
5 changes: 0 additions & 5 deletions arrayjit/lib/cc_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,8 @@ type procedure = {
}
[@@deriving sexp_of]

let expected_merge_node proc = proc.lowered.merge_node
let is_in_context node = Tnode.is_in_context_force node.Low_level.tn 33

let header_sep =
let open Re in
compile (seq [ str " "; opt any; str "="; str " " ])

let get_global_run_id =
let next_id = ref 0 in
fun () ->
Expand Down
1 change: 1 addition & 0 deletions arrayjit/lib/cc_backend.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include Backend_types.Lowered_no_device_backend
2 changes: 0 additions & 2 deletions arrayjit/lib/gcc_backend.gccjit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,6 @@ type procedure = {
}
[@@deriving sexp_of]

let expected_merge_node proc = proc.expected_merge_node

let is_in_context node =
Tnode.default_to_most_local node.Low_level.tn 33;
match node.tn.memory_mode with
Expand Down

0 comments on commit ef76c9b

Please sign in to comment.