open Staged
open Basetypes
open Algebra

module Float_Set (* : SET *) =
struct
  (** Carrier type of the instance: OCaml [float]. *)
  type n = float

  (** A float that is either an immediate value ([Now]) or a piece of
      code to be computed later ([Later]). *)
  type 'a ns = ('a, n) staged

  (** Staged binary relations and binary operations on [ns]. *)
  type 'a rel = 'a ns -> 'a ns -> 'a Bool.b
  type 'a bin = 'a ns -> 'a ns -> 'a ns

  (** Conversion to string via [string_of_float], at either stage. *)
  let to_string_b =
    { unow = string_of_float;
      ulater = (fun v -> lift_comp .< string_of_float .~(v.c) >.) }
  let to_string_s v = mk_unary to_string_b v

  (** Exact (structural) equality on floats. *)
  let eq_b =
    { bnow = (fun p q -> p = q);
      blater = (fun p q -> lift_comp .< .~(p.c) = .~(q.c) >.) }
  let eq_s a b = mk_binary eq_b a b

  (** Exact (structural) inequality on floats. *)
  let neq_b =
    { bnow = (fun p q -> p <> q);
      blater = (fun p q -> lift_comp .< .~(p.c) <> .~(q.c) >.) }
  let neq_s a b = mk_binary neq_b a b

  (** Approximate equality: true when [|a - b| < 1e-6]
      (a fixed absolute tolerance). *)
  let eq_tol a b =
    let close_enough =
      { bnow = (fun p q -> abs_float (p -. q) < 1e-6);
        blater =
          (fun p q -> lift_comp .< abs_float (.~(p.c) -. .~(q.c)) < 1e-6 >.) }
    in
    mk_binary close_enough a b

  (** Injection of integer constants into the carrier. *)
  let of_int = float_of_int
end

module Float_Order (* : ORDER *) =
struct
  include Float_Set
  type t = n
  type 'a t_s = 'a ns

  (** Bottom and top elements of the order: the IEEE infinities. *)
  let bot = neg_infinity
  let top = infinity

  let eq = eq_s
  let neq = neq_s

  (* Each [*_b] record pairs the immediate float operation ([bnow]) with a
     generator emitting the same operation as code ([blater]). The matching
     [*_s] function hands both to [mk_binary], which presumably selects one
     depending on whether the arguments are [Now] or [Later]. *)

  (** Three-way comparison using [Pervasives.compare]. *)
  let compare_b =
    { bnow = Pervasives.compare;
      blater = (fun a b -> lift_comp .< Pervasives.compare .~(a.c) .~(b.c) >.) }
  let compare_s a b = mk_binary compare_b a b

  (** Strict less-than. *)
  let lt_b =
    { bnow = ( < );
      blater = (fun a b -> lift_comp .< .~(a.c) < .~(b.c) >.) }
  let lt_s a b = mk_binary lt_b a b

  (** Less-than-or-equal. *)
  let le_b =
    { bnow = ( <= );
      blater = (fun a b -> lift_comp .< .~(a.c) <= .~(b.c) >.) }
  let le_s a b = mk_binary le_b a b

  (** Strict greater-than. *)
  let gt_b =
    { bnow = ( > );
      blater = (fun a b -> lift_comp .< .~(a.c) > .~(b.c) >.) }
  let gt_s a b = mk_binary gt_b a b

  (** Greater-than-or-equal. *)
  let ge_b =
    { bnow = ( >= );
      blater = (fun a b -> lift_comp .< .~(a.c) >= .~(b.c) >.) }
  let ge_s a b = mk_binary ge_b a b

  (** Maximum, via the polymorphic [max]. *)
  let max_b =
    { bnow = max;
      blater = (fun a b -> lift_comp .< max .~(a.c) .~(b.c) >.) }
  let max_s a b = mk_binary max_b a b

  (** Minimum, via the polymorphic [min]. *)
  let min_b =
    { bnow = min;
      blater = (fun a b -> lift_comp .< min .~(a.c) .~(b.c) >.) }
  let min_s a b = mk_binary min_b a b
end

module Float_Add_Monoid (* : MONOID *) =
struct
  include Float_Set

  (** Identity element of addition. *)
  let zero = 0.

  (** Float addition: [( +. )] immediately, an emitted [+.] when staged. *)
  let plus_op =
    { bnow = ( +. );
      blater = (fun a b -> lift_comp .< .~(a.c) +. .~(b.c) >.) }

  (** The additive monoid: operation [plus_op] with unit [zero]. *)
  let plus_b = { bop = plus_op; uelem = zero }
  let plus_s a b = mk_monoid plus_b a b
end

module Float_Normed_Set (* : NORMED_SET *) =
struct
  include Float_Set
  module R = Float_Add_Monoid

  (** Norm = absolute value, written as [if v < 0. then -. v else v]. *)
  let norm_b =
    let abs_now v = if v < 0. then -. v else v in
    let abs_later v =
      let body =
        if v.a then
          (* an atom can be repeated in the generated code at no cost *)
          .< if .~(v.c) < 0. then -. .~(v.c) else .~(v.c) >.
        else
          (* otherwise bind it once so its computation is not duplicated *)
          .< let t = .~(v.c) in if t < 0. then -. t else t >.
      in
      lift_comp body
    in
    { unow = abs_now; ulater = abs_later }
  let norm_s x = mk_unary norm_b x
end

module Float_Add_Normed_Monoid (* : NORMED_MONOID *) =
struct
  include Float_Normed_Set

  (** Identity element of addition. *)
  let zero = 0.

  (** Float addition: [( +. )] immediately, an emitted [+.] when staged. *)
  let plus_op =
    { bnow = ( +. );
      blater = (fun a b -> lift_comp .< .~(a.c) +. .~(b.c) >.) }

  (** The additive monoid: operation [plus_op] with unit [zero]. *)
  let plus_b = { bop = plus_op; uelem = zero }
  let plus_s a b = mk_monoid plus_b a b
end

module Float_Ring_Base =
struct
  include Float_Set

  (** Distinguished constants. *)
  let zero = 0.
  let one = 1.
  let negone = -1.
  let two = 2.

  (** Addition, packaged as a monoid with unit [0.]. *)
  let add_op = { bnow = (fun x y -> x +. y);
                 blater = (fun x y -> lift_comp .< .~(x.c) +. .~(y.c) >.); }
  let add_b = { bop = add_op; uelem = 0. }
  let add_s a b = mk_monoid add_b a b

  (** Negation. *)
  let neg_b = { unow = (fun x -> -. x);
                ulater = (fun x -> lift_comp .< -. .~(x.c) >.) }
  let neg_s x = mk_unary neg_b x

  (** Subtraction. In IEEE arithmetic [x -. y] is exactly [x +. (-. y)],
      so the immediate result matches the generated code. *)
  let sub_b =
    { bnow = (fun x y -> x -. y);
      blater = (fun x y -> lift_comp .< .~(x.c) -. .~(y.c) >.) }

  (** [sub_s a b] is [a - b], with the simplifications [a - 0 = a] and
      [0 - b = -b] for immediate zeros. The comparison with [zero] must be
      done in a guard: a bare [Now zero] pattern binds a fresh variable and
      would match every immediate value. *)
  let sub_s a b =
    match a, b with
    | _, Now y when y = zero -> a
    | Now x, _ when x = zero -> neg_s b
    (* | Later a, Later b when a = b -> Now zero *)
    | _, _ -> mk_binary sub_b a b

  (** Multiplication; together with [add_b] it forms the ring [mul_b]
      ([monp] = additive monoid, [mont] = multiplicative monoid). *)
  let mul_op = { bnow = (fun x y -> x *. y);
                 blater = (fun x y -> lift_comp .< .~(x.c) *. .~(y.c) >.) }
  let mul_mon = { bop = mul_op; uelem = one }
  let mul_b = { monp = add_b; mont = mul_mon }

  (** [mul_s x y] is [x * y]. An immediate [-1.] factor becomes a negation
      of the other operand. Matching on the constructor avoids applying
      polymorphic equality to staged values that may contain code. *)
  let mul_s x y =
    match x, y with
    | Now v, _ when v = negone -> neg_s y
    | _, Now v when v = negone -> neg_s x
    | _, _ -> mk_ring mul_b x y

  (** Absolute value via [abs_float]. *)
  let abs_b = { unow = (fun x -> abs_float x);
                ulater = (fun x -> lift_comp .< abs_float .~(x.c) >.) }
  let abs_s x = mk_unary abs_b x

  (** [pow x y] is [x ** y], immediate or staged. *)
  let pow x y =
    let bnow x y = x ** y
    and blater x y = lift_comp .< .~(x.c) ** .~(y.c) >. in
    mk_binary { bnow = bnow; blater = blater } x y

  (** [_pow_ x i] is the product of [i] copies of [x], built with [mul_s]
      ([Now 1.] when [i = 0]).
      @raise Invalid_argument if [i < 0]. *)
  let rec _pow_ x = function
    | 0 -> Now 1.
    | i when i < 0 -> invalid_arg "Float_Ring_Base._pow_: negative exponent"
    | i -> mul_s x (_pow_ x (i - 1))

  (** [int_pow n x] is [x] raised to the integer power [n].
      The following cases go through [pow]:
      - an immediate base,
      - an exponent above the unrolling threshold of 7,
      - a negative exponent.
      Otherwise the product is unrolled in the generated code, with the base
      let-bound once so it is not recomputed. *)
  let int_pow n x =
    match n with
    | 0 -> Now 1.
    | 1 -> x
    | n when n < 0 || n > 7 || is_now x -> pow x (Now (of_int n))
    | n ->
      (* [x] is [Later] here: bind it once, then unroll over the atom [y] *)
      let x = to_code x in
      Later (lift_comp .< let y = .~x in
                          .~(to_code (_pow_ (of_atom .<y>.) n)) >.)

  (* NORMED_MONOID stuff *)
  let plus_b = add_b
  let plus_s = add_s
  let norm_b = abs_b
  let norm_s = abs_s
end

module Float_Sign_Inexact
    (E : sig val eps : float * ('a, float) code end) =
struct
  (* [eps] is the tolerance as an immediate float, [eps_c] the tolerance as
     code (presumably the same value; supplied together by [E]). *)
  let eps = fst E.eps
  let eps_c = snd E.eps

  (** Sign of [x] with tolerance [eps]: values in [-eps, eps] count as zero.
      [Sign.bind] receives the positive, zero, negative, non-negative and
      non-positive tests, in that order. *)
  let sgn x =
    match x with
    | Now v ->
      Sign.bind
        (Now (v > eps))
        (Now (v >= -. eps && v <= eps))
        (Now (v < -. eps))
        (Now (v >= -. eps))
        (Now (v <= eps))
    | Later v ->
      Sign.bind
        (Later (lift_comp .< .~(v.c) > .~eps_c >.))
        (Later (lift_comp .< let y = .~(v.c) and e = .~eps_c in
                             y >= -. e && y <= e >.))
        (Later (lift_comp .< .~(v.c) < -. .~eps_c >.))
        (Later (lift_comp .< .~(v.c) >= -. .~eps_c >.))
        (Later (lift_comp .< .~(v.c) <= .~eps_c >.))
end

module Float_Sign_Exact =
struct
  (** Exact sign of [x], comparing against [0.] with no tolerance.
      [Sign.bind] receives the positive, zero, negative, non-negative and
      non-positive tests, in that order. *)
  let sgn x =
    match x with
    | Now v ->
      Sign.bind
        (Now (v > 0.))
        (Now (v = 0.))
        (Now (v < 0.))
        (Now (v >= 0.))
        (Now (v <= 0.))
    | Later v ->
      Sign.bind
        (Later (lift_comp .< .~(v.c) > 0. >.))
        (Later (lift_comp .< .~(v.c) = 0. >.))
        (Later (lift_comp .< .~(v.c) < 0. >.))
        (Later (lift_comp .< .~(v.c) >= 0. >.))
        (Later (lift_comp .< .~(v.c) <= 0. >.))
end

module Float_Ring_Exact (* : RING *) =
struct
  (* sign test comparing against 0. exactly *)
  include Float_Sign_Exact
  (* float ring operations *)
  include Float_Ring_Base
end

module Float_Ring_Inexact
    (E : sig val eps : float * ('a, float) code end) (* : RING *) =
struct
  (* sign test treating values within E.eps of zero as zero *)
  include Float_Sign_Inexact (E)
  (* float ring operations *)
  include Float_Ring_Base
end

module Float_Field_Base =
struct
  include Float_Ring_Base

  (** pi, computed as [4 * atan 1]. *)
  let pi = 4.0 *. atan 1.0

  (** Multiplicative inverse [1 /. x]. There is no special case for zero;
      IEEE division yields an infinity. *)
  let inv_b = { unow = (fun x -> 1. /. x);
                ulater = (fun x -> lift_comp .< 1. /. .~(x.c) >.) }
  let inv_s a = mk_unary inv_b a

  (** Division [x /. y]. The immediate case divides directly so that it
      agrees bit-for-bit with the generated code. It must not go through
      [mul_b.monp], which is the additive monoid. *)
  let div_b = { bnow = (fun x y -> x /. y);
                blater = (fun x y -> lift_comp .< .~(x.c) /. .~(y.c) >.) }
  let div_s a b = mk_binary div_b a b
end

module Float_Field_Exact (* : FIELD *) =
struct
  (* sign test comparing against 0. exactly *)
  include Float_Sign_Exact
  (* float field operations *)
  include Float_Field_Base
end

module Float_Field_Inexact
    (E : sig val eps : float * ('a, float) code end) (* : FIELD *) =
struct
  (* sign test treating values within E.eps of zero as zero *)
  include Float_Sign_Inexact (E)
  (* float field operations *)
  include Float_Field_Base
end

module Float_Real_Base =
struct
  include Float_Field_Base
  (* The order operations repeat those of [Float_Order]. Including both
     modules would duplicate the types from [Float_Set], so they are
     restated here rather than shared through a functor. *)

  (** Bottom and top elements of the order: the IEEE infinities. *)
  let bot = neg_infinity
  let top = infinity

  (** Three-way comparison using [Pervasives.compare]. *)
  let compare_b =
    { bnow = Pervasives.compare;
      blater = (fun a b -> lift_comp .< Pervasives.compare .~(a.c) .~(b.c) >.) }
  let compare_s a b = mk_binary compare_b a b

  (** Strict less-than. *)
  let lt_b =
    { bnow = ( < );
      blater = (fun a b -> lift_comp .< .~(a.c) < .~(b.c) >.) }
  let lt_s a b = mk_binary lt_b a b

  (** Less-than-or-equal. *)
  let le_b =
    { bnow = ( <= );
      blater = (fun a b -> lift_comp .< .~(a.c) <= .~(b.c) >.) }
  let le_s a b = mk_binary le_b a b

  (** Strict greater-than. *)
  let gt_b =
    { bnow = ( > );
      blater = (fun a b -> lift_comp .< .~(a.c) > .~(b.c) >.) }
  let gt_s a b = mk_binary gt_b a b

  (** Greater-than-or-equal. *)
  let ge_b =
    { bnow = ( >= );
      blater = (fun a b -> lift_comp .< .~(a.c) >= .~(b.c) >.) }
  let ge_s a b = mk_binary ge_b a b

  (** Maximum, via the polymorphic [max]. *)
  let max_b =
    { bnow = max;
      blater = (fun a b -> lift_comp .< max .~(a.c) .~(b.c) >.) }
  let max_s a b = mk_binary max_b a b

  (** Minimum, via the polymorphic [min]. *)
  let min_b =
    { bnow = min;
      blater = (fun a b -> lift_comp .< min .~(a.c) .~(b.c) >.) }
  let min_s a b = mk_binary min_b a b

  (** Square root via [sqrt], immediate or staged. *)
  let sqrt_s x =
    mk_unary { unow = sqrt; ulater = (fun v -> lift_comp .< sqrt .~(v.c) >.) } x
end

module Float_Real_Exact (* : REALFIELD *) =
struct
  (* sign test comparing against 0. exactly *)
  include Float_Sign_Exact
  (* float field, order and sqrt operations *)
  include Float_Real_Base
end

module Float_Real_Inexact
    (E : sig val eps : float * ('a, float) code end) (* : REALFIELD *) =
struct
  (* sign test treating values within E.eps of zero as zero *)
  include Float_Sign_Inexact (E)
  (* float field, order and sqrt operations *)
  include Float_Real_Base
end
