open Staged
open Basetypes
open Algebra

module Rational_Set (* : SET *) =
struct
  (* A rational is represented by an unnormalised pair
     [(numerator, denominator)]: [(1, 2)] and [(2, 4)] denote the same value.
     NOTE(review): nothing here enforces a non-zero denominator; callers are
     presumably expected to maintain it. *)
  type n = int * int
  type 'a ns = ('a, int * int) staged
  type 'a rel = 'a ns -> 'a ns -> 'a Bool.b
  type 'a bin = 'a ns -> 'a ns -> 'a ns

  (* Renders a rational as "num/den" without reducing it. *)
  let to_string_b =
    { unow = (fun (x, y) -> string_of_int x ^ "/" ^ string_of_int y);
      ulater = (fun x -> lift_comp
        .< let x, y = .~(x.c) in
           string_of_int x ^ "/" ^ string_of_int y >.); }
  let to_string_s v = mk_unary to_string_b v

  (* Value equality.  Pairs are never normalised by the arithmetic in this
     file, so equality is decided by cross-multiplication
     (x/y = x'/y'  <=>  x*y' = x'*y) rather than component-wise; otherwise
     e.g. 1/2 and 2/4 would compare unequal.  Both stages agree. *)
  let eq_b = { bnow = (fun (x, y) (x', y') -> x * y' = x' * y);
               blater = (fun a b -> lift_comp
                 .< let x, y = .~(a.c) and x', y' = .~(b.c) in
                    x * y' = x' * y >.); }
  let eq_s a b = mk_binary eq_b a b

  (* Exact complement of [eq_b] in both stages.  A component-wise
     [x <> x' && y <> y'] would not be the negation of equality
     (it reports 1/2 vs 1/3 as "not different"). *)
  let neq_b = { bnow = (fun x y -> not (eq_b.bnow x y));
                blater = (fun a b -> lift_comp
                  .< let x, y = .~(a.c) and x', y' = .~(b.c) in
                     x * y' <> x' * y >.); }
  let neq_s a b = mk_binary neq_b a b

  (* Rationals are exact, so tolerant equality is plain equality. *)
  let eq_tol = eq_s

  (* Embeds an integer [x] as the rational x/1. *)
  let of_int x = (x, 1)
end

module Rational_Ring (* : RING *) =
struct
  include Rational_Set

  (* Constants as (numerator, denominator) pairs. *)
  let zero = (0, 1)
  let one = (1, 1)
  let negone = (-1, 1)
  let two = (2, 1)

  (* a/b + c/d = (a*d + b*c) / (b*d); the result is not reduced. *)
  let add_op = { bnow = (fun (a, b) (c, d) -> (a * d + b * c, b * d));
                 blater = (fun x y -> lift_comp
                   .< let a, b = .~(x.c) and c, d = .~(y.c) in
                      (a * d + b * c, b * d) >.); }
  let add_b = { bop = add_op; uelem = zero }
  let add_s a b = mk_monoid add_b a b

  (* -(a/b) = (-a)/b *)
  let neg_b = { unow = (fun (x, y) -> (-x, y));
                ulater = (fun x -> lift_comp
                  .< let a, b = .~(x.c) in (-a, b) >.) }
  let neg_s x = mk_unary neg_b x

  (* a/b - c/d = (a*d - b*c) / (b*d); the result is not reduced. *)
  let sub_b =
    { bnow = (fun (a, b) (c, d) -> (a * d - b * c, b * d));
      blater = (fun x y -> lift_comp
        .< let a, b = .~(x.c) and c, d = .~(y.c) in
           (a * d - b * c, b * d) >.); }

  (* Subtraction, simplified when an operand is a statically-known zero
     (any present-stage value whose numerator is 0).  The zero test must be
     a literal pattern: writing [Now zero] would bind a fresh variable named
     [zero] and match every [Now] value. *)
  let sub_s a b = match a, b with
    | a, Now (0, _) -> a
    | Now (0, _), b -> neg_s b
    | a, b -> mk_binary sub_b a b

  (* (a/b) * (c/d) = (a*c) / (b*d).  Each operand's code is spliced exactly
     once in the staged branch, so it is not duplicated in generated code. *)
  let mul_op = { bnow = (fun (a, b) (c, d) -> (a * c, b * d));
                 blater = (fun x y -> lift_comp
                   .< let a, b = .~(x.c) and c, d = .~(y.c) in
                      (a * c, b * d) >.); }
  let mul_mon = { bop = mul_op; uelem = one }
  let mul_b = { monp = add_b; mont = mul_mon }

  (* Multiplication; a present-stage factor equal to [negone] is turned into
     a negation of the other factor.  Pattern matching avoids polymorphic
     equality on staged values. *)
  let mul_s x y = match x, y with
    | Now (-1, 1), _ -> neg_s y
    | _, Now (-1, 1) -> neg_s x
    | _ -> mk_ring mul_b x y

  let abs_int x = if x < 0 then -x else x

  (* Absolute value taken component-wise on numerator and denominator. *)
  let abs_b = { unow = (fun (a, b) -> (abs_int a, abs_int b));
                ulater = (fun x -> lift_comp
                  .< let a, b = .~(x.c) in
                     let a' = if a < 0 then -a else a
                     and b' = if b < 0 then -b else b in
                     (a', b') >.) }
  let abs_s x = mk_unary abs_b x

  let eps = 1e-6
  let eps_c = .<1e-6>.

  (* Sign predicates, each in a present-stage ([...n]) and a code ([...c])
     form that must agree.  NOTE(review): they appear to assume a positive
     denominator (a negative numerator and a negative denominator would be
     classified as negative) — confirm that invariant with callers. *)
  let posn (a, b) = a > 0 && b > 0
  let posc x = .< let a, b = .~(x.c) in a > 0 && b > 0 >.

  let zern (a, _) = a = 0
  let zerc x = .< let a, _ = .~(x.c) in a = 0 >.

  let negn (a, b) = a < 0 || b < 0
  let negc x = .< let a, b = .~(x.c) in a < 0 || b < 0 >.

  let pozn (a, b) = a >= 0 && b > 0
  let pozc x = .< let a, b = .~(x.c) in a >= 0 && b > 0 >.

  let nezn (a, b) = a <= 0 || b < 0
  let nezc x = .< let a, b = .~(x.c) in a <= 0 || b < 0 >.

  (* Builds the sign descriptor from the five predicates, at the same stage
     as the argument. *)
  let sgn x =
    let sgn_now x = Sign.bind (Now (posn x)) (Now (zern x))
      (Now (negn x)) (Now (pozn x)) (Now (nezn x))

    and sgn_later x = Sign.bind (Later (lift_comp (posc x)))
                                (Later (lift_comp (zerc x)))
                                (Later (lift_comp (negc x)))
                                (Later (lift_comp (pozc x)))
                                (Later (lift_comp (nezc x)))
    in
    match x with
    | Now x -> sgn_now x
    | Later x -> sgn_later x

  (* Not implemented for rationals; raises [Failure]. *)
  let pow x y = failwith "not implemented"
  (* Intended as an optimized version with threshold 7; not implemented,
     raises [Failure]. *)
  let int_pow n x = failwith "not implemented"
end

(*
module Float_Field (* : FIELD *) =
struct
  include Float_Ring
  let pi = 4.0 *. atan 1.0
  let inv_b = { unow = (fun x -> 1. /. x);
               ulater = (fun x -> lift_comp .< 1. /. .~(x.c) >.) }
  let inv_s a = mk_unary inv_b a
  let div_b = { bnow = (fun x y -> mul_b.mop.bnow x (inv_b.unow y));
                blater = (fun x y -> lift_comp .< .~(x.c) /. .~(y.c) >.) }
  let div_s a b = mk_binary div_b a b
end
*)
