
open Staged
open Basetypes
open Util

(* Interface of a fixed-dimension tuple of components.  Every operation
   comes in two flavours:
   - [_n] ("now") variants work on immediate values (implementations wrap
     components with [Now] and read results back with
     [Staged.to_immediate]);
   - [_c] variants work on a [code_expr] and produce a [code_expr], i.e.
     they generate MetaOCaml code.
   Implementations are expected to make each [_c] function generate code
   computing the same value as its [_n] counterpart, and to fold the
   components in index order (component 0 first) -- NOTE(review): this is
   the intended contract, confirm against each implementation. *)
module type TUPLE =
sig
  type 'a t
  (* Number of components. *)
  val dim : int
  (* [init n f] builds the tuple [f 0, ..., f (dim-1)]; the visible
     implementations ignore the first argument -- presumably meant to be
     [dim], TODO confirm. *)
  val init : int -> (int -> 'a) -> 'a t
  (* [let_ v k] gives [k] access to [v]; for a staged [Later] value the
     implementations let-bind it in the generated code first. *)
  val let_ :  ('a, 'b t) staged ->
    (('a, 'b t) staged -> ('a, 'c t) staged) -> ('a, 'c t) staged
  (* Projection of component [i] (0-based), immediate / generated. *)
  val proj_n : 'b t -> int -> 'b
  val proj_c : ('a, 'b t) code_expr -> int -> ('a, 'b) code_expr
  (* Conversions of a tuple of staged components into a single staged
     tuple.  [to_imm] presumably requires every component to be
     immediate (see the PRECOND TODOs in the implementations). *)
  val to_code : ('a, 'c) staged t -> ('a, 'c t) code
  val to_expr : ('a, 'c) staged t -> ('a, 'c t) code_expr
  val to_imm : ('a, 'c) staged t -> 'c t

  (* Conversions between tuples and lists of length [dim]. *)
  val of_list_n : 'b list -> 'b t
  val of_list_c : ('a, 'b) code_expr list -> ('a, 'b t) code_expr
  val to_list_n : 'a t -> 'a list
  val to_list_c : ('a, 'b t) code_expr -> ('a, 'b) code_expr list
  (* val map : ('a t -> 'b t) -> 'a t -> 'b t *)
  (* Component-wise application of a staged unary function. *)
  val map_n : ('a, 'b, 'c) unary_fun -> 'b t -> 'c t
  val map_c : ('a, 'b, 'c) unary_fun ->
                  ('a, 'b t) code_expr -> ('a, 'c t) code_expr

  (* Like [map], but the function also receives the component index. *)
  val mapi_n : (int -> ('a, 'b) staged -> ('a, 'c) staged) -> 'b t -> 'c t
  val mapi_c : (int -> ('a, 'b) staged -> ('a, 'c) staged) ->
                  ('a, 'b t) code_expr -> ('a, 'c t) code_expr

  (* Component-wise combination of two tuples. *)
  val map2_n : ('a, 'b, 'c, 'd) binary_fun -> 'b t -> 'c t -> 'd t
  val map2_c : ('a, 'b, 'c, 'd) binary_fun ->  ('a, 'b t) code_expr ->
                  ('a, 'c t) code_expr -> ('a, 'd t) code_expr

  (* [fold f z t]: reduce the components with [f component acc],
     starting from the immediate accumulator [z]. *)
  val fold_n : ('a, 'b, 'c, 'c) binary_fun -> 'c -> 'b t -> 'c
  val fold_c : ('a, 'b, 'c, 'c) binary_fun -> 'c ->
                  ('a, 'b t) code_expr -> ('a, 'c) code_expr

  (* [mapfold m f z t]: intended as [fold f z (map m t)] in one pass. *)
  val mapfold_n : ('a, 'b, 'b) unary_fun -> ('a, 'b, 'c, 'c) binary_fun ->
                    'c -> 'b t -> 'c
  val mapfold_c : ('a, 'b, 'b) unary_fun -> ('a, 'b, 'c, 'c) binary_fun ->
                    'c -> ('a, 'b t) code_expr -> ('a, 'c) code_expr

  (* [map2fold m f z t t']: intended as [fold f z (map2 m t t')], e.g. a
     dot product. *)
  val map2fold_n : ('a, 'b, 'c, 'd) binary_fun -> ('a, 'd, 'e, 'e) binary_fun ->
                    'e -> 'b t -> 'c t -> 'e
  val map2fold_c : ('a, 'b, 'c, 'd) binary_fun -> ('a, 'd, 'e, 'e) binary_fun ->
                    'e -> ('a, 'b t) code_expr ->
		      ('a, 'c t) code_expr -> ('a, 'e) code_expr
end

(* ******************************************** *)
(* ******************************************** *)
(* 1D tuple *)

module Tuple1D =
struct
  (* A one-dimensional "tuple" is simply the bare component, so every
     operation reduces to applying its function argument once.  The [_n]
     variants wrap the value with [Now] and unwrap the result with
     [Staged.to_immediate]; the [_c] variants go through [of_expr] and
     [Staged.to_expr]. *)
  type 'a t = 'a
  let dim = 1
  let init _ mk = mk 0
  let proj_n v _ = v
  let proj_c v _ = v
  let to_code v = Staged.to_code v
  let to_expr v = Staged.to_expr v
  (* @TODO implement PRECOND \forall is_now *)
  let to_imm v = Staged.to_immediate v
  let of_list_n xs = List.hd xs
  let of_list_c xs = List.hd xs
  let to_list_n v = [v]
  let to_list_c v = [v]
  (* @TODO PRECOND \forall x. is_now f(x) *)

  let map_n f v =
    let r = f (Now v) in
    Staged.to_immediate r
  let map_c f v =
    let r = f (of_expr v) in
    Staged.to_expr r

  let mapi_n f v =
    let r = f 0 (Now v) in
    Staged.to_immediate r
  let mapi_c f v =
    let r = f 0 (of_expr v) in
    Staged.to_expr r

  let map2_n f v w =
    let r = f (Now v) (Now w) in
    Staged.to_immediate r
  let map2_c f v w =
    let r = f (of_expr v) (of_expr w) in
    Staged.to_expr r

  (* The single component is combined once with the initial accumulator. *)
  let fold_n f acc v =
    let r = f (Now v) (Now acc) in
    Staged.to_immediate r
  let fold_c f acc v =
    let r = f (of_expr v) (Now acc) in
    Staged.to_expr r

  let mapfold_n m f acc v =
    let r = f (m (Now v)) (Now acc) in
    Staged.to_immediate r
  let mapfold_c m f acc v =
    let r = f (m (of_expr v)) (Now acc) in
    Staged.to_expr r

  let map2fold_n m f acc v w =
    let r = f (m (Now v) (Now w)) (Now acc) in
    Staged.to_immediate r
  let map2fold_c m f acc v w =
    let r = f (m (of_expr v) (of_expr w)) (Now acc) in
    Staged.to_expr r
  let let_ = Code.let_
end

(* ******************************************** *)
(* ******************************************** *)
(* ******************************************** *)

(* Implementation *)

(* Record representation of a 2-tuple, plus immediate and code-level
   constructors / accessors.  Indexed access treats 0 as [c0] and any
   other index as [c1]. *)
type 'a rec2_type = { c0 : 'a ; c1 : 'a }
let rec2_create c0 c1 = { c0; c1 }
let rec2_create_c c0 c1 = .<{ c0 = .~c0; c1 = .~c1 }>.
let rec2_get r = function
  | 0 -> r.c0
  | _ -> r.c1
let rec2_get_c r = function
  | 0 -> .<(.~r).c0>.
  | _ -> .<(.~r).c1>.
let rec2_fst r = r.c0
let rec2_snd r = r.c1
let rec2_fst_c r = .<(.~r).c0>.
let rec2_snd_c r = .<(.~r).c1>.

(* Two-component tuple stored in a [rec2_type] record.
   [_n] operations work on immediate values; [_c] operations build code.
   The [_c] operations first pass their code argument through [Code.letc_]
   before projecting its fields (presumably so the argument expression is
   not duplicated in the generated code).
   All folds visit the components in index order ([c0] then [c1]), so
   every [_c] operation generates code computing the same value as its
   [_n] counterpart. *)
module Record2D =
struct
  type 'a t = 'a rec2_type
  let dim = 2
  (* The first argument is ignored; components are [f 0] and [f 1]. *)
  let init i f = rec2_create (f 0) (f 1)
  let proj_n t i = rec2_get t i
  (* The projection is lifted with [lift_atom] when [t.a] holds and with
     [lift_comp] otherwise.
     @TODO fail on index values other than 0 or 1; currently any non-zero
     index selects [c1]. *)
  let proj_c t i =
    let x = (rec2_get_c t.c i) in
    if t.a then lift_atom x else lift_comp x
  let to_code t =
    rec2_create_c (Staged.to_code (rec2_get t 0))
                  (Staged.to_code (rec2_get t 1))
  let to_expr t = lift_comp (to_code t)
  (* @TODO implement PRECOND \forall is_now *)
  let to_imm t =
    rec2_create (Staged.to_immediate (rec2_get t 0))
                (Staged.to_immediate (rec2_get t 1))
  let of_list_n l = rec2_create (List.nth l 0) (List.nth l 1)
  let of_list_c l =
    let x = (List.nth l 0) and y = (List.nth l 1) in
    lift_comp (rec2_create_c x.c y.c)
  let to_list_n t = [ (rec2_get t 0); (rec2_get t 1) ]
  let to_list_c t = [ lift (rec2_get_c t.c 0) t.a;
                      lift (rec2_get_c t.c 1) t.a ]
  (* @TODO PRECOND \forall x. is_now f(x) *)
  let map_n f t = of_list_n [Staged.to_immediate (f (Now (rec2_fst t)));
                             Staged.to_immediate (f (Now (rec2_snd t)))]
  let map_c f t =
    Code.letc_ t (fun t ->
      let x = Staged.of_comp (rec2_fst_c t.c)
      and y = Staged.of_comp (rec2_snd_c t.c) in
      let x' = Staged.to_code (f x)
      and y' = Staged.to_code (f y) in
      of_list_c [lift_comp x'; lift_comp y'])

  let mapi_n f t = of_list_n [Staged.to_immediate (f 0 (Now (rec2_fst t)));
                              Staged.to_immediate (f 1 (Now (rec2_snd t)))]
  let mapi_c f t =
    Code.letc_ t (fun t ->
      let x = Staged.of_comp (rec2_fst_c t.c)
      and y = Staged.of_comp (rec2_snd_c t.c) in
      let x' = Staged.to_code (f 0 x)
      and y' = Staged.to_code (f 1 y) in
      of_list_c [lift_comp x'; lift_comp y'])

  let map2_n f t t' =
    of_list_n [ Staged.to_immediate (f (Now t.c0) (Now t'.c0));
                Staged.to_immediate (f (Now t.c1) (Now t'.c1)) ]
  let map2_c f t t' =
    Code.letc_ t (fun t ->
     Code.letc_ t' (fun t' ->
      let x = Staged.of_comp (rec2_fst_c t.c)
      and y = Staged.of_comp (rec2_snd_c t.c)
      and x' = Staged.of_comp (rec2_fst_c t'.c)
      and y' = Staged.of_comp (rec2_snd_c t'.c) in
      let c0 = Staged.to_code (f x x')
      and c1 = Staged.to_code (f y y') in
      (* Same generated record as the previous inline
         [.<{ c0 = .~c0; c1 = .~c1 }>.]. *)
      lift_comp (rec2_create_c c0 c1)))

  (* [fold f z t] = [f c1 (f c0 z)]: components in index order. *)
  let fold_n f z t =
    let x = Now (rec2_fst t) and y = Now (rec2_snd t) in
    Staged.to_immediate (f y (f x (Now z)))
  let fold_c f z t =
    Code.letc_ t (fun t ->
      let x = Staged.of_comp (rec2_fst_c t.c)
      and y = Staged.of_comp (rec2_snd_c t.c) in
      let r = Staged.to_code (f y (f x (Now z))) in
      lift_comp r)

  (* [mapfold m f z t] = [fold f z (map m t)] = [f (m c1) (f (m c0) z)].
     Previously the components were combined in the order c1 then c0,
     unlike [fold]. *)
  let mapfold_n m f z t =
    let x = rec2_fst t and y = rec2_snd t in
    let r = f (m (Now y)) (f (m (Now x)) (Now z)) in
    Staged.to_immediate r
  let mapfold_c m f z t =
    Code.letc_ t (fun t ->
      let x = Staged.of_comp (rec2_fst_c t.c)
      and y = Staged.of_comp (rec2_snd_c t.c) in
      let r = f (m y) (f (m x) (Now z))
     in lift_comp (Staged.to_code r))

  (* [map2fold m f z t t'] = [fold f z (map2 m t t')].  The [_c] variant
     must use the same component order as [map2fold_n]; the previous
     version folded c1 before c0, so the generated code and the immediate
     computation disagreed for non-commutative (or floating-point) [f]. *)
  let map2fold_n m f z t t' = fold_n f z (map2_n m t t')
  let map2fold_c m f z t t' =
    Code.letc_ t (fun t -> Code.letc_ t' (fun t' ->
     let x = Staged.of_comp (rec2_fst_c t.c)
     and y = Staged.of_comp (rec2_snd_c t.c)
     and x' = Staged.of_comp (rec2_fst_c t'.c)
     and y' = Staged.of_comp (rec2_snd_c t'.c) in
     let r = f (m y y') (f (m x x') (Now z)) in
     lift_comp (Staged.to_code r)))
  (* [let_ ce exp]: an immediate value is passed to [exp] directly; a
     [Later] value is let-bound in the generated code and [exp] receives
     the bound variable as an atom. *)
  let let_ ce exp = match ce with
    | Now _ -> exp ce
    | Later c -> of_comp .< let t = .~(c.c) in
                   .~(Staged.to_code (exp (of_atom .<t>.))) >.
end

(* Two-component tuple represented as a native OCaml pair.
   [_n] operations work on immediate values; [_c] operations build code,
   usually by destructuring the pair with [let x,y = ... in] in the
   generated code.  All folds visit the components in index order (first
   then second), so each [_c] operation generates code computing the same
   value as its [_n] counterpart. *)
module Pair2D =
struct
  type 'a t = 'a * 'a
  let dim = 2
  (* The first argument is ignored; components are [f 0] and [f 1]. *)
  let init i f = (f 0), (f 1)
  let proj_n (a,b) i = if i = 0 then a else b
  (* @TODO fail on index values other than 0 or 1; currently any non-zero
     index selects the second component. *)
  let proj_c t i =
    if i = 0 then lift_comp .< fst .~(t.c) >.
    else lift_comp .< snd .~(t.c) >.
  let to_code (a,b) = .<.~(Staged.to_code a), .~(Staged.to_code b)>.
  let to_expr t = lift_comp (to_code t)
  (* [let_ ce exp]: an immediate value is passed to [exp] directly; a
     [Later] value is let-bound in the generated code and [exp] receives
     the bound variable as an atom. *)
  let let_ ce exp = match ce with
    | Now _ -> exp ce
    | Later c -> of_comp .< let t = .~(c.c) in
                   .~(Staged.to_code (exp (of_atom .<t>.))) >.
  (* @TODO implement PRECOND \forall is_now *)
  let to_imm (a,b) = (Staged.to_immediate a), (Staged.to_immediate b)
  let of_list_n l = (List.nth l 0), (List.nth l 1)
  let of_list_c l =
    let x = List.nth l 0
    and y = List.nth l 1 in
    lift_comp .< .~(x.c), .~(y.c) >.
  let to_list_n t = [ fst t; snd t ]
  let to_list_c t = [ lift (proj_c t 0).c t.a;
                      lift (proj_c t 1).c t.a ]
  (* @TODO PRECOND \forall x. is_now f(x) *)
  let map_n f (a,b) = of_list_n [Staged.to_immediate (f (Now a));
                                 Staged.to_immediate (f (Now b))]
  let map_c f t = lift_comp .< let x,y = .~(t.c) in
      .~(Staged.to_code (f (of_atom .<x>.))),
      .~(Staged.to_code (f (of_atom .<y>.))) >.

  let mapi_n f (a,b) = of_list_n [Staged.to_immediate (f 0 (Now a));
                                  Staged.to_immediate (f 1 (Now b))]
  let mapi_c f t = lift_comp .< let x,y = .~(t.c) in
      .~(Staged.to_code (f 0 (of_atom .<x>.))),
      .~(Staged.to_code (f 1 (of_atom .<y>.))) >.

  let map2_n f (a,b) (a',b') =
    Staged.to_immediate (f (Now a) (Now a')),
    Staged.to_immediate (f (Now b) (Now b'))
  (* @TODO artifact will remain because there is no general
     mechanism for let2 or letN variables *)
  let map2_c f t t' = lift_comp
    .< let x,y = .~(t.c) and x',y' = .~(t'.c) in
      .~(Staged.to_code (f (of_atom .<x>.) (of_atom .<x'>.))),
      .~(Staged.to_code (f (of_atom .<y>.) (of_atom .<y'>.))) >.

  (* [fold f z (a,b)] = [f b (f a z)]: components in index order. *)
  let fold_n f z (a,b) =
    let x = Now a and y = Now b in
    Staged.to_immediate (f y (f x (Now z)))
  (* Must match [fold_n]; the previous version generated [f x (f y z)],
     folding the second component first. *)
  let fold_c f z t = lift_comp
    .< let x,y = .~(t.c) in
        .~(let x = of_atom .<x>. and y = of_atom .<y>. in
           Staged.to_code (f y (f x (Now z)))) >.

  (* [mapfold m f z t] = [fold f z (map m t)] = [f (m b) (f (m a) z)]. *)
  let mapfold_n m f z (x,y) =
    let r = f (m (Now y)) (f (m (Now x)) (Now z)) in
    Staged.to_immediate r
  let mapfold_c m f z t = lift_comp
    .< let x,y = .~(t.c) in
      .~(let x = of_atom .<x>. and y = of_atom .<y>. in
         Staged.to_code (f (m y) (f (m x) (Now z)))) >.

  (* [map2fold m f z t t'] = [fold f z (map2 m t t')]; the [_c] variant
     uses the same component order as [map2fold_n]. *)
  let map2fold_n m f z t t' = fold_n f z (map2_n m t t')
  (* NOTE(review): [t.c] and [t'.c] are each spliced twice (under [fst]
     and [snd]) and the projections are marked atomic.  This is only
     cheap and correct if the inputs are already atoms; otherwise the
     generated code evaluates them twice.  The let-binding alternative
     below avoids that but leaves a tuple-pattern artifact in the
     generated code:
     .< let x,y = .~(t.c) and x',y' = .~(t'.c) in
        .~(let x = of_atom .<x>. and y = of_atom .<y>.
           and x' = of_atom .<x'>. and y' = of_atom .<y'>. in
           let r = f (m y y') (f (m x x') (Now z)) in
           Staged.to_code r) >. *)
  let map2fold_c m f z t t' = lift_comp
    .< .~(let x = .<fst .~(t.c)>. and y = .<snd .~(t.c)>.
          and x' = .<fst .~(t'.c)>. and y' = .<snd .~(t'.c)>. in
          let x = of_atom x and y = of_atom y
          and x' = of_atom x' and y' = of_atom y' in
          let r = f (m y y') (f (m x x') (Now z)) in
          Staged.to_code r) >.
end

(* Tuple 3D *)

(* Implementation *)

(* Record representation of a 3-tuple, plus immediate and code-level
   constructors / accessors.  Indexed access accepts 0, 1 and 2 and raises
   [Invalid_argument "index out of bounds"] for any other index. *)
type 'a rec3_type = { x : 'a ; y : 'a ; z : 'a }
let rec3_create x y z = { x; y; z }
let rec3_create_c x y z = .<{ x = .~x; y = .~y; z = .~z }>.
let rec3_get r i =
  match i with
  | 0 -> r.x
  | 1 -> r.y
  | 2 -> r.z
  | _ -> invalid_arg "index out of bounds"
let rec3_get_c r i =
  match i with
  | 0 -> .<(.~r).x>.
  | 1 -> .<(.~r).y>.
  | 2 -> .<(.~r).z>.
  | _ -> invalid_arg "index out of bounds"

(* Three-component tuple stored in a [rec3_type] record.
   [_n] operations work on immediate values; [_c] operations build code.
   The [_c] operations first pass their code argument through [Code.letc_]
   before projecting its fields (presumably so the argument expression is
   not duplicated in the generated code).
   All folds visit the components in index order (x, y, then z), so every
   [_c] operation generates code computing the same value as its [_n]
   counterpart. *)
module Record3D =
struct
  type 'a t = 'a rec3_type
  let dim = 3
  (* The first argument is ignored; components are [f 0], [f 1], [f 2]. *)
  let init i f = rec3_create (f 0) (f 1) (f 2)
  let proj_n t i = rec3_get t i
  (* The projection is lifted with [lift_atom] when [t.a] holds and with
     [lift_comp] otherwise. *)
  let proj_c t i =
    let x = (rec3_get_c t.c i) in
    if t.a then lift_atom x else lift_comp x
  let to_code t =
    rec3_create_c (Staged.to_code (rec3_get t 0))
      (Staged.to_code (rec3_get t 1)) (Staged.to_code (rec3_get t 2))
  let to_expr t = lift_comp (to_code t)
  (* [let_ ce exp]: an immediate value is passed to [exp] directly; a
     [Later] value is let-bound in the generated code and [exp] receives
     the bound variable as an atom. *)
  let let_ ce exp = match ce with
    | Now _ -> exp ce
    | Later c -> of_comp .< let t = .~(c.c) in
                   .~(Staged.to_code (exp (of_atom .<t>.))) >.
  (* @TODO implement PRECOND \forall is_now *)
  let to_imm t =
    rec3_create (Staged.to_immediate (rec3_get t 0))
                (Staged.to_immediate (rec3_get t 1))
                (Staged.to_immediate (rec3_get t 2))
  let of_list_n l = rec3_create (List.nth l 0) (List.nth l 1) (List.nth l 2)
  let of_list_c l =
    let x = (List.nth l 0) and y = (List.nth l 1)
    and z = (List.nth l 2) in
    lift_comp (rec3_create_c x.c y.c z.c)
  let to_list_n t = [ (rec3_get t 0); (rec3_get t 1); (rec3_get t 2) ]
  let to_list_c t = [ lift (rec3_get_c t.c 0) t.a;
                      lift (rec3_get_c t.c 1) t.a;
                      lift (rec3_get_c t.c 2) t.a ]
  (* @TODO PRECOND \forall x. is_now f(x) *)
  let map_n f t = of_list_n [Staged.to_immediate (f (Now (rec3_get t 0)));
                             Staged.to_immediate (f (Now (rec3_get t 1)));
                             Staged.to_immediate (f (Now (rec3_get t 2)))]
  let map_c f t =
    Code.letc_ t (fun t ->
      let x = Staged.of_comp (rec3_get_c t.c 0)
      and y = Staged.of_comp (rec3_get_c t.c 1)
      and z = Staged.of_comp (rec3_get_c t.c 2) in
      let x' = Staged.to_code (f x)
      and y' = Staged.to_code (f y)
      and z' = Staged.to_code (f z) in
      of_list_c [lift_comp x'; lift_comp y'; lift_comp z'])

  let mapi_n f t = of_list_n [Staged.to_immediate (f 0 (Now (rec3_get t 0)));
                              Staged.to_immediate (f 1 (Now (rec3_get t 1)));
                              Staged.to_immediate (f 2 (Now (rec3_get t 2)))]
  let mapi_c f t =
    Code.letc_ t (fun t ->
      let x = Staged.of_comp (rec3_get_c t.c 0)
      and y = Staged.of_comp (rec3_get_c t.c 1)
      and z = Staged.of_comp (rec3_get_c t.c 2) in
      let x' = Staged.to_code (f 0 x)
      and y' = Staged.to_code (f 1 y)
      and z' = Staged.to_code (f 2 z) in
      of_list_c [lift_comp x'; lift_comp y'; lift_comp z'])

  let map2_n f t t' =
    of_list_n [ Staged.to_immediate (f (Now t.x) (Now t'.x));
                Staged.to_immediate (f (Now t.y) (Now t'.y));
                Staged.to_immediate (f (Now t.z) (Now t'.z)) ]
  let map2_c f t t' =
    Code.letc_ t (fun t ->
     Code.letc_ t' (fun t' ->
      let x = Staged.of_comp (rec3_get_c t.c 0)
      and y = Staged.of_comp (rec3_get_c t.c 1)
      and z = Staged.of_comp (rec3_get_c t.c 2)
      and x' = Staged.of_comp (rec3_get_c t'.c 0)
      and y' = Staged.of_comp (rec3_get_c t'.c 1)
      and z' = Staged.of_comp (rec3_get_c t'.c 2) in
      let c0 = Staged.to_code (f x x')
      and c1 = Staged.to_code (f y y')
      and c2 = Staged.to_code (f z z') in
      lift_comp (rec3_create_c c0 c1 c2)))

  (* [fold f zer t] = [f z (f y (f x zer))]: components in index order.
     ([zer] names the initial accumulator to avoid clashing with [z].) *)
  let fold_n f zer t =
    let x = Now (rec3_get t 0) and y = Now (rec3_get t 1)
    and z = Now (rec3_get t 2) in
    Staged.to_immediate (f z (f y (f x (Now zer))))
  let fold_c f zer t =
    Code.letc_ t (fun t ->
      let x = Staged.of_comp (rec3_get_c t.c 0)
      and y = Staged.of_comp (rec3_get_c t.c 1)
      (* Third component is index 2; index 1 here used to fold [y] twice
         and drop [z]. *)
      and z = Staged.of_comp (rec3_get_c t.c 2) in
      let r = Staged.to_code (f z (f y (f x (Now zer)))) in
      lift_comp r)

  (* [mapfold m f zer t] = [fold f zer (map m t)]
     = [f (m z) (f (m y) (f (m x) zer))].  Previously the components were
     combined in the order y, x, z, unlike [fold]. *)
  let mapfold_n m f zer t =
    let x = rec3_get t 0 and y = rec3_get t 1
    and z = rec3_get t 2 in
    let r = f (m (Now z)) (f (m (Now y)) (f (m (Now x)) (Now zer))) in
    Staged.to_immediate r
  let mapfold_c m f zer t =
    Code.letc_ t (fun t ->
      let x = Staged.of_comp (rec3_get_c t.c 0)
      and y = Staged.of_comp (rec3_get_c t.c 1)
      and z = Staged.of_comp (rec3_get_c t.c 2) in
      let r = f (m z) (f (m y) (f (m x) (Now zer)))
     in lift_comp (Staged.to_code r))

  (* [map2fold m f zer t t'] = [fold f zer (map2 m t t')].  The [_c]
     variant uses the same x, y, z order as [map2fold_n]; it previously
     used y, x, z, so staged and immediate results could differ. *)
  let map2fold_n m f zer t t' = fold_n f zer (map2_n m t t')
  let map2fold_c m f zer t t' =
    Code.letc_ t (fun t -> Code.letc_ t' (fun t' ->
     let x = Staged.of_comp (rec3_get_c t.c 0)
     and y = Staged.of_comp (rec3_get_c t.c 1)
     and z = Staged.of_comp (rec3_get_c t.c 2)
     and x' = Staged.of_comp (rec3_get_c t'.c 0)
     and y' = Staged.of_comp (rec3_get_c t'.c 1)
     and z' = Staged.of_comp (rec3_get_c t'.c 2) in
     let r = f (m z z') (f (m y y') (f (m x x') (Now zer))) in
     lift_comp (Staged.to_code r)))
end

(* Interface of a fixed-dimension tuple whose operations take and return
   [staged] values directly, rather than the separate [_n] / [_c] variants
   of [TUPLE].  Presumably each operation dispatches on whether its
   arguments are [Now] or [Later] -- TODO confirm per implementation. *)
module type TUPLE_STAGED =
sig
  type 'a t
  (* Number of components. *)
  val dim : int
  (* [init n f] builds the tuple [f 0, ..., f (dim-1)]; the first argument
     may be ignored by implementations -- NOTE(review): confirm. *)
  val init : int -> (int -> 'a) -> 'a t
  (* [let_ v k] gives [k] access to [v], let-binding it in generated code
     when [v] is not immediate. *)
  val let_ :  ('a, 'b t) staged ->
    (('a, 'b t) staged -> ('a, 'c t) staged) -> ('a, 'c t) staged

  (* val map : ('a t -> 'b t) -> 'a t -> 'b t *)
  (* Component-wise map, indexed map and binary map. *)
  val map : ('a, 'b, 'c) unary_fun -> ('a, 'b t) staged -> ('a, 'c t) staged
  val mapi : (int -> ('a, 'b) staged -> ('a, 'c) staged) -> ('a, 'b t) staged -> ('a, 'c t) staged
  val map2 : ('a, 'b, 'c, 'd) binary_fun -> ('a, 'b t) staged -> ('a, 'c t) staged -> ('a, 'd t) staged
  (* [fold f z t]: reduce the components with [f component acc] starting
     from the immediate accumulator [z]. *)
  val fold : ('a, 'b, 'c, 'c) binary_fun -> 'c -> ('a, 'b t) staged -> ('a, 'c) staged

  (* [mapfold m f z t]: intended as [fold f z (map m t)].  Note that [m]
     may change the component type here, unlike [TUPLE.mapfold_n]. *)
  val mapfold : ('a, 'b, 'd) unary_fun -> ('a, 'd, 'c, 'c) binary_fun ->
                    'c -> ('a, 'b t) staged -> ('a, 'c) staged

  (* [map2fold m f z t t']: intended as [fold f z (map2 m t t')]. *)
  val map2fold : ('a, 'b, 'c, 'd) binary_fun -> ('a, 'd, 'e, 'e) binary_fun ->
                    'e -> ('a, 'b t) staged -> ('a, 'c t) staged -> ('a, 'e) staged
end

(* Pair-based tuple with [staged]-level operations: each operation
   packages the immediate ([Pair2D.*_n]) and code-generating
   ([Pair2D.*_c]) implementations and lets [mk_unary] / [mk_binary]
   choose between them for the given argument(s). *)
module Pair2DS =
struct
  type 'a t = 'a Pair2D.t
  let dim = 2
  (* The first argument is ignored; components are [f 0] and [f 1]. *)
  let init i f = (f 0), (f 1)
  (* [let_ ce exp]: an immediate value is passed to [exp] directly; a
     [Later] value is let-bound in the generated code and [exp] receives
     the bound variable as an atom. *)
  let let_ ce exp = match ce with
    | Now _ -> exp ce
    | Later c -> of_comp .< let t = .~(c.c) in
                   .~(Staged.to_code (exp (of_atom .<t>.))) >.
  let map f v =
    let unow t = Pair2D.map_n f t
    and ulater t = (Pair2D.map_c f t) in
    mk_unary { unow = unow; ulater = ulater } v

  (* Indexed map, declared by [TUPLE_STAGED]; built exactly like [map]. *)
  let mapi f v =
    let unow t = Pair2D.mapi_n f t
    and ulater t = (Pair2D.mapi_c f t) in
    mk_unary { unow = unow; ulater = ulater } v

  (* Both branches must pass the arguments in the same order: the code
     branch used to call [map2_c f v' v], applying [f] with its arguments
     swapped (and forcing both element types to be equal). *)
  let map2 f v v' =
    let bnow v v' = Pair2D.map2_n f v v'
    and blater v v' = (Pair2D.map2_c f v v') in
    mk_binary { bnow = bnow; blater = blater } v v'

  let fold f z v =
    let unow v = Pair2D.fold_n f z v
    and ulater v = (Pair2D.fold_c f z v) in
    mk_unary { unow = unow; ulater = ulater } v

  let mapfold m f z v =
    let unow v = Pair2D.mapfold_n m f z v
    and ulater v = (Pair2D.mapfold_c m f z v) in
    mk_unary { unow = unow; ulater = ulater } v

  let map2fold m f z v v' =
    let bnow v v' = Pair2D.map2fold_n m f z v v'
    and blater v v' = (Pair2D.map2fold_c m f z v v') in
    mk_binary { bnow = bnow; blater = blater } v v'

end

