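(*
  T_s = Now T | Later <T>

  T  -- of_immediate --> Now T = T_s
  |                      |
 lift                to_later
  |                      |
  v                      v
 <T> -- of_code -->  Later <T> = T_s
     <-- to_code --

  Legend: <T> denotes generated code of type T; "lift" is quoting a value
  into code (.<v>., see lift_const); "of_code" stands for of_atom / of_comp
  / of_expr, which wrap code into Later.
*)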

(* @NOTE generated code will contain cross-stage persistent (CSP) values:
   immediate values are turned into code by quoting them as .<v>. *)

(* Generated code together with a flag telling whether it is an atom
   ([a = true], see [lift_atom]) or a computation ([a = false], see
   [lift_comp]). *)
type ('a, 'b) code_expr = {
  c : ('a, 'b) code;  (* the generated code itself *)
  a : bool;           (* true for atoms, false for computations *)
}

(* [lift code is_atom] pairs [code] with its atom flag. *)
let lift code is_atom = { c = code; a = is_atom }

(* Tag [code] as an atom ([a = true]). *)
let lift_atom code = { c = code; a = true }

(* Tag [code] as a computation ([a = false]). *)
let lift_comp code = { c = code; a = false }

(* Quote a present-stage value into atomic code (cross-stage persistence). *)
let lift_const v = lift_atom .<v>.

(* A staged value: either known at the present stage ([Now]) or only
   available as generated code ([Later]). *)
type ('a, 'b) staged =
  | Now of 'b
  | Later of ('a, 'b) code_expr

(* [is_atom s] is true only for [Later] code tagged as an atom; a [Now]
   value is never reported as an atom. *)
let is_atom = function
  | Later { a; _ } -> a
  | Now _ -> false

(* [is_now s] is true iff [s] is a present-stage value. *)
let is_now = function
  | Now _ -> true
  | Later _ -> false

(* [is_later s] is true iff [s] is generated code. *)
let is_later s =
  match s with
  | Later _ -> true
  | Now _ -> false

(* Constructors into [staged]. *)

(* Wrap an already-tagged code expression. *)
let of_expr e = Later e

(* Wrap raw code, tagging it as an atom. *)
let of_atom code = Later { a = true; c = code }

(* Wrap raw code, tagging it as a computation. *)
let of_comp code = Later { a = false; c = code }

(* Wrap a present-stage value. *)
let of_immediate v = Now v

(* Force a staged value into its [Later] form; a [Now] value is quoted
   (.<v>.) and tagged as an atom. [Later] values are returned unchanged. *)
let to_later s =
  match s with
  | Later _ -> s
  | Now v -> Later { a = true; c = .<v>. }

(* Extract the generated code, quoting a [Now] value if necessary. *)
let to_code s =
  match s with
  | Later { c; _ } -> c
  | Now v -> .<v>.

(* Extract the present-stage value.
   @raise Failure "to_immediate" if [s] is [Later]. *)
let to_immediate s =
  match s with
  | Now v -> v
  | Later _ -> failwith "to_immediate"

(* Extract the code expression, quoting a [Now] value as an atom. *)
let to_expr s =
  match s with
  | Later e -> e
  | Now v -> lift_atom .<v>.

(* Helpers that run a staged transformer [f : staged -> staged] (or a
   two-argument one for the [apply2c*] family) on code or on immediate
   values, and return the resulting generated code via [to_code].

   The [_atom] / [_comp] variants first tag raw code as an atom or a
   computation; the [_imm] variants wrap present-stage values with [Now].
   Each variant applies [f] to the lifted argument(s) itself, rather than
   handing a partially applied [f x] to [applyc] / [apply2c] (which would
   yield a function instead of code). *)

(* [applyc f e] runs [f] on the code expression [e]. *)
let applyc f x = to_code (f (of_expr x))

(* [applyc_atom f code] runs [f] on [code] tagged as an atom. *)
let applyc_atom f x = applyc f (lift_atom x)

(* [applyc_comp f code] runs [f] on [code] tagged as a computation. *)
let applyc_comp f x = applyc f (lift_comp x)

(* [app_imm f v] runs [f] on the present-stage value [v]. *)
let app_imm f x = to_code (f (of_immediate x))

(* [apply2c f e1 e2] runs the binary transformer [f] on two code exprs. *)
let apply2c f x y = to_code (f (of_expr x) (of_expr y))

(* [apply2c_atom f c1 c2] runs [f] on two pieces of code tagged as atoms. *)
let apply2c_atom f x y = apply2c f (lift_atom x) (lift_atom y)

(* [apply2c_comp f c1 c2] runs [f] on two pieces of code tagged as
   computations. *)
let apply2c_comp f x y = apply2c f (lift_comp x) (lift_comp y)

(* Historical name of [apply2c_comp], kept for existing callers. *)
let apply2_comp f x y = apply2c_comp f x y

(* [apply2c_imm f v1 v2] runs [f] on two present-stage values. *)
let apply2c_imm f x y = to_code (f (of_immediate x) (of_immediate y))

(* Inlining algebra
   Atom: values (atoms) can be inlined
   Comp: computations cannot be inlined
   type ('a, 'b) expr = Atom of 'a | Comp of 'b
*)

(* Unary operator, given both as a present-stage function ([unow]) and as a
   code generator ([ulater]); combined by [mk_unary]. *)
type ('a, 'b, 'c) unary = {
  unow : 'b -> 'c;
  ulater : ('a, 'b) code_expr -> ('a, 'c) code_expr;
}

(* A unary operator lifted to staged values. *)
type ('a, 'b, 'c) unary_fun =
  ('a, 'b) staged -> ('a, 'c) staged

(* Plain binary operator, given both as a present-stage function ([bnow])
   and as a code generator ([blater]); combined by [mk_binary]. *)
type ('a, 'b, 'c, 'd) binary = {
  bnow : 'b -> 'c -> 'd;
  blater :
    ('a, 'b) code_expr -> ('a, 'c) code_expr -> ('a, 'd) code_expr;
}

(* A binary operator lifted to staged values. *)
type ('a, 'b, 'c, 'd) binary_fun =
  ('a, 'b) staged -> ('a, 'c) staged -> ('a, 'd) staged

(* A homogeneous binary operator together with its unit element; used by
   [mk_monoid] to simplify [unit * a] and [a * unit] to [a]. *)
type ('a, 'b) monoid = {
  bop : ('a, 'b, 'b, 'b) binary;  (* the operation *)
  uelem : 'b;                     (* its unit element *)
}

(* Two monoid structures over the same carrier: [monp] is the additive one
   (its unit is the ring's zero) and [mont] the multiplicative one. *)
type ('a, 'b) ring = {
  monp : ('a, 'b) monoid;  (* additive: (+, zero) *)
  mont : ('a, 'b) monoid;  (* multiplicative: ( * , one) *)
}

(* Build a staged constant from a present-stage value. *)
let mk_const = fun v -> Now v

(* Apply a [unary] operator to a staged value: a [Now] value is computed
   immediately with [unow]; [Later] code is transformed with [ulater]. *)
let mk_unary op s =
  match s with
  | Later e -> Later (op.ulater e)
  | Now v -> Now (op.unow v)

(* Apply a [binary] operator to two staged values. When both are [Now] the
   result is computed immediately with [bnow]; otherwise each [Now] operand
   is quoted (.<v>.) into an atom and [blater] generates the code. *)
let mk_binary op x y =
  (* Polymorphic helper: [x] and [y] may have different carrier types. *)
  let as_code = function
    | Now v -> lift_atom .<v>.
    | Later e -> e
  in
  match x, y with
  | Now u, Now v -> Now (op.bnow u v)
  | _ -> Later (op.blater (as_code x) (as_code y))

(* [mk_monoid mon x y] combines [x] and [y] with [mon.bop], first applying
   the unit laws  one * a = a  and  a * one = a : if either operand is a
   [Now] value equal to [mon.uelem], the other operand is returned as is
   (no code is generated). All remaining cases are exactly those of
   [mk_binary], so they are delegated to it instead of being duplicated.
   Unit detection uses polymorphic equality on the carrier type. *)
let mk_monoid mon x y =
  match x, y with
  | Now u, other when u = mon.uelem -> other
  | other, Now u when u = mon.uelem -> other
  | _ -> mk_binary mon.bop x y

(* Multiplication in a [ring]. If one operand is a [Now] value equal to the
   additive unit (zero) and the other is [Later] code, the result is [Now]
   zero and no code is generated. Every other case is multiplied with
   [mk_monoid] over [rng.mont], which also applies the unit laws for one.
   NOTE(review): the discarded [Later] operand may be a computation
   ([a = false]); confirm that dropping it is intended. *)
let mk_ring rng x y =
  let zero = rng.monp.uelem in
  match x, y with
  | Now u, Later _ | Later _, Now u when u = zero -> Now zero
  | _ -> mk_monoid rng.mont x y

(* ************* *)

(* Algebra for expressions?
   Staged      = Immediate | Code
   Expression  = Code | Static
   Code expr   = Atomic | Computation
   Static expr = composite type over Staged
*)
