Debug kernel inputs when with_debug_level > 1
Missing for the gccjit backend.
Also print host addresses in print / log_accessible_headers.
lukstafi committed Jul 27, 2024
1 parent c1f28dc commit 52162a0
Showing 6 changed files with 64 additions and 18 deletions.
36 changes: 30 additions & 6 deletions arrayjit/lib/backend_utils.ml
@@ -29,11 +29,6 @@ end

module Tn = Tnode

let get_c_ptr nd =
let prec = Ndarray.get_prec nd in
let f arr = Ops.ptr_to_string (Ctypes.bigarray_start Ctypes_static.Genarray arr) prec in
Ndarray.(map { f } nd)

module C_syntax (B : sig
val for_lowereds : Low_level.optimized array

@@ -111,7 +106,7 @@ struct
| false, _, _, true, Some (Hosted _, _), true ->
(* In-context nodes to read directly from host would be error prone. *)
let nd = Option.value_exn ~here:[%here] @@ Lazy.force node.tn.array in
fprintf ppf "#define %s (%s)@," (get_ident node.tn) (get_c_ptr nd);
fprintf ppf "#define %s (%s)@," (get_ident node.tn) (Ndarray.c_ptr_to_string nd);
Hash_set.add is_global node.tn
| _ -> ()));
fprintf ppf "@,@]";
@@ -326,6 +321,35 @@
(* FIXME: we should also close the file. *)
if (not (List.is_empty log_file)) && not B.logs_to_stdout then
fprintf ppf {|FILE* log_file = fopen(log_file_name, "w");@ |};
if Utils.settings.debug_log_from_routines && Utils.settings.with_debug_level > 1 then (
fprintf ppf "/* Debug initial parameter state. */@ ";
List.iter
~f:(function
| p_name, Merge_buffer ->
if B.logs_to_stdout then
fprintf ppf {|@[<7>printf(@[<h>"%s%%d: %s = %%p\n",@] log_id, (void*)merge_buffer);@]@ |}
!Utils.captured_log_prefix p_name
else
fprintf ppf {|@[<7>fprintf(log_file,@ @[<h>"%s = %%p\n",@] (void*)merge_buffer);@]@ |}
p_name
| _, Log_file_name -> ()
| p_name, Param_ptr tn ->
if B.logs_to_stdout then
fprintf ppf {|@[<7>printf(@[<h>"%s%%d: %s = %%p\n",@] log_id, (void*)%s);@]@ |}
!Utils.captured_log_prefix p_name
@@ get_ident tn
else
fprintf ppf {|@[<7>fprintf(log_file,@ @[<h>"%s = %%p\n",@] (void*)%s);@]@ |} p_name
@@ get_ident tn
| p_name, Static_idx s ->
if B.logs_to_stdout then
fprintf ppf {|@[<7>printf(@[<h>"%s%%d: %s = %%d\n",@] log_id, %s);@]@ |}
!Utils.captured_log_prefix p_name
@@ Indexing.symbol_ident s.Indexing.static_symbol
else
fprintf ppf {|@[<7>fprintf(log_file,@ @[<h>"%s = %%d\n",@] %s);@]@ |} p_name
@@ Indexing.symbol_ident s.Indexing.static_symbol)
params);
fprintf ppf "/* Local declarations and initialization. */@ ";
Hashtbl.iteri traced_store ~f:(fun ~key:tn ~data:node ->
if not (Tn.is_virtual_force tn 333 || B.is_in_context node || Hash_set.mem is_global tn)
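For orientation, the parameter-logging branch added above reduces to the following standalone sketch. This is not the library code: the real version lives inside the C_syntax functor, its Param_ptr and Static_idx constructors carry Tnode.t and Indexing payloads rather than strings, and the stdout case additionally prints !Utils.captured_log_prefix and a log_id.

(* Sketch of the per-parameter logging emission, with simplified payloads. *)
type param_source = Merge_buffer | Log_file_name | Param_ptr of string | Static_idx of string

let emit_param_log ~logs_to_stdout ppf (p_name, source) =
  (* [c_conv] is the C conversion spec; [value_expr] the C expression to log. *)
  let emit c_conv value_expr =
    if logs_to_stdout then
      Format.fprintf ppf {|printf("%s = %s\n", %s);@ |} p_name c_conv value_expr
    else
      Format.fprintf ppf {|fprintf(log_file, "%s = %s\n", %s);@ |} p_name c_conv value_expr
  in
  match source with
  | Log_file_name -> () (* the log file handle itself is not logged *)
  | Merge_buffer -> emit "%p" "(void*)merge_buffer"
  | Param_ptr ident -> emit "%p" ("(void*)" ^ ident)
  | Static_idx sym -> emit "%d" sym

For a context parameter x26 with ~logs_to_stdout:false this emits fprintf(log_file, "x26 = %p\n", (void*)x26); into the generated C, so each routine records the addresses it was actually called with.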
8 changes: 3 additions & 5 deletions arrayjit/lib/cc_backend.ml
@@ -14,8 +14,6 @@ let optimization_level () =

let compiler_command () = Utils.get_global_arg ~default:"cc" ~arg_name:"cc_backend_compiler_command"

(** Currently unused, backend behaves as if [config] is always [`Physical_devices_only]. *)

module Tn = Tnode

type ctx_array = Ndarray.t [@@deriving sexp_of]
@@ -47,7 +45,7 @@ let to_buffer ?rt:_ tn ~dst ~src =

let host_to_buffer ?rt:_ src ~dst = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
let buffer_to_host ?rt:_ dst ~src = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
let unsafe_cleanup () = ()
let unsafe_cleanup () = Stdlib.Gc.compact ()

let is_initialized, initialize =
let initialized = ref false in
@@ -105,7 +103,7 @@ let%track_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_
type nonrec ctx_array = ctx_array

let opt_ctx_arrays = opt_ctx_arrays
let hardcoded_context_ptr = Some Backend_utils.get_c_ptr
let hardcoded_context_ptr = Some Ndarray.c_ptr_to_string
let is_in_context = is_in_context
let host_ptrs_for_readonly = true
let logs_to_stdout = false
@@ -156,7 +154,7 @@ let%track_sexp compile_batch ~names ~opt_ctx_arrays bindings
type nonrec ctx_array = ctx_array

let opt_ctx_arrays = opt_ctx_arrays
let hardcoded_context_ptr = Some Backend_utils.get_c_ptr
let hardcoded_context_ptr = Some Ndarray.c_ptr_to_string
let is_in_context = is_in_context
let host_ptrs_for_readonly = true
let logs_to_stdout = false
5 changes: 5 additions & 0 deletions arrayjit/lib/ndarray.ml
@@ -365,6 +365,11 @@ let retrieve_flat_values arr =

(** {2 *** Printing ***} *)

let c_ptr_to_string nd =
let prec = get_prec nd in
let f arr = Ops.ptr_to_string (Ctypes.bigarray_start Ctypes_static.Genarray arr) prec in
map { f } nd

(** Dimensions to string, ["x"]-separated, e.g. 1x2x3 for batch dims 1, input dims 3, output dims 2.
Outputs ["-"] for empty dimensions. *)
let int_dims_to_string ?(with_axis_numbers = false) dims =
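The new helper is get_c_ptr moved here from Backend_utils (deleted above), so that code outside the backends, such as Tnode.header below, can also render host addresses. A minimal usage sketch, with a hypothetical wrapper name:

(* Sketch: [nd] is assumed to be an existing hosted array; the exact pointer
   rendering comes from Ops.ptr_to_string at [nd]'s precision. *)
let log_host_address ~name (nd : Ndarray.t) =
  Stdlib.Printf.printf "%s @ %s\n%!" name (Ndarray.c_ptr_to_string nd)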
4 changes: 3 additions & 1 deletion arrayjit/lib/tnode.ml
@@ -271,7 +271,9 @@ let header arr =
if Lazy.is_val arr.array then
match arr.array with
| (lazy None) -> "<not-hosted>"
| (lazy (Some nd)) -> Int.to_string_hum @@ Nd.size_in_bytes nd
| (lazy (Some nd)) ->
let size = Int.to_string_hum @@ Nd.size_in_bytes nd in
if Utils.settings.with_debug_level > 0 then size ^ " @ " ^ Nd.c_ptr_to_string nd else size
else "<not-in-yet>"
in
let repeating_nograd_idents = Hashtbl.create ~size:1 (module String) in
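With this change, a hosted node's header carries its host address next to the byte size whenever with_debug_level > 0, e.g. 1_200 @ 0x7f3a2c000010 (address illustrative), which lets the addresses printed by the debug preamble in backend_utils.ml be matched back to tensor headers.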
26 changes: 20 additions & 6 deletions bin/moons_benchmark.ml
@@ -30,6 +30,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
CDSL.virtualize_settings.max_visits <- inlining_cutoff;
Tensor.default_value_prec := value_prec;
Tensor.default_grad_prec := grad_prec;
Utils.settings.with_debug_level <- 3;
Utils.settings.output_debug_files_in_run_directory <- true;
Utils.settings.debug_log_from_routines <- true;
Rand.init (* seed *) 0;
@@ -38,8 +39,9 @@
(* let hid_dim = 4 in *)
let len = batch_size * 20 in
(* let epochs = 100 in *)
(* let epochs = 10 in *)
let epochs = 5 in
(* let epochs = 20 in *)
(* let epochs = 5 in *)
let epochs = 1 in
let init_lr = 0.1 in
let noise () = Rand.float_range (-0.1) 0.1 in
let moons_flat =
@@ -74,6 +76,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
epoch_loss
in
let module Backend = (val backend) in
Backend.initialize Train.BT.Most_parallel_devices;
let inputs, outputs, model_result, infer_callback, batch_losses, epoch_losses, learning_rates =
Train.example_train_loop ~seed ~batch_size ~init_lr ~max_num_devices:num_devices ~data_len:len
~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn ~weight_decay
@@ -154,12 +157,12 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
Backend.unsafe_cleanup ();
result

let () =
let _suspend () =
ignore
@@ classify_moons ~seed:0 ~on_device:true ~inlining_cutoff:3 ~num_devices:8 ~batch_size:16
~backend_name:"gccjit" ~value_prec:CDSL.single ~grad_prec:CDSL.double ()

let benchmarks =
let _cpu_benchmarks =
List.concat_map [ 0; 1; 2; 3 ] ~f:(fun inlining_cutoff ->
List.concat_map [ 1; 2; 5; 8; 10; 16 (* ; 20 *) ] ~f:(fun num_devices ->
List.concat_map [ 120; 160 (* ; 320; 640; 1280 *) ] ~f:(fun batch_size ->
@@ -170,6 +173,17 @@ let benchmarks =
~batch_size ~backend_name ~value_prec:CDSL.single ~grad_prec:CDSL.single;
])))))

let cuda_benchmarks =
List.concat_map [ 0; (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
List.concat_map [ 1; 2; 5; 8; 10; 16; 20; 30 (* ; 32; 40; 64 *) ] ~f:(fun num_devices ->
List.concat_map [ 120; 160 (* ; 320; 640; 1280 *) ] ~f:(fun batch_size ->
List.concat_map [ 0; 1 (* ; 2; 3; 4 *) ] ~f:(fun seed ->
List.concat_map [ (* "gccjit" ; *) "cc" (* ; "cuda" *) ] ~f:(fun backend_name ->
[
classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_devices
~batch_size ~backend_name ~value_prec:CDSL.single ~grad_prec:CDSL.single;
])))))
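(* Aside: with the commented-out alternatives excluded, this grid enumerates 2
   inlining cutoffs x 8 device counts x 2 batch sizes x 2 seeds x 1 backend =
   64 classify_moons runs, each with with_debug_level = 3 set above. *)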

(* let time_of = function PrintBox_utils.Benchmark { time_in_sec; _ } -> time_in_sec let nth_best
nth bench = let results = List.init 5 ~f:(fun seed -> bench ~seed ()) in let sorted = List.sort
results ~compare:(fun r1 r2 -> Float.compare (time_of r1) (time_of r2)) in List.nth_exn sorted
@@ -186,7 +200,7 @@ let _suspended () =
Stdio.stdout *)

let benchmark () =
List.map benchmarks ~f:(fun bench -> bench ())
List.map cuda_benchmarks ~f:(fun bench -> bench ())
|> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout

let _suspended () = benchmark ()
let () = benchmark ()
3 changes: 3 additions & 0 deletions lib/train.ml
@@ -483,6 +483,9 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
f ~at_batch:!batch_ref ~at_step:!step_ref ~learning_rate:learning_rate.@[0] ~batch_loss
~epoch_loss:!epoch_loss))
in
if Utils.settings.with_debug_level > 1 then (
Stdlib.Printf.printf "\nTraining...\n%!";
Tn.log_accessible_headers ());
for epoch = 0 to epochs - 1 do
epoch_loss := 0.;
update ();
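For reference, a minimal sketch of opting into this dump from user code, mirroring the settings set in moons_benchmark.ml above:

(* Raising with_debug_level above 1 triggers this header dump; together with
   debug_log_from_routines it also enables the kernel-parameter logging
   added in backend_utils.ml. *)
let () =
  Utils.settings.with_debug_level <- 2;
  Utils.settings.debug_log_from_routines <- true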
