open Staged
open Basetypes
open Algebra
open Vector
open Tuple

(** Vector operations over [N.n T.t] tuples whose value is either known
    at generation time ([Now]) or only available as code ([Later]).

    Every primitive dispatches through [mk_unary] / [mk_binary], supplying
    an immediate implementation (the [T.*_n] functions) and a
    code-producing implementation (the [T.*_c] functions). *)
module VectorStaged (N : REALFIELD) (T : TUPLE) (* : VECTOR *) =
struct
  module N = N
  type vector = N.n T.t
  type 'a vector_s = ('a, vector) staged
  type coordinate_enum = int
  let dim = T.dim

  (** [zero n] is an immediate vector built with [T.init n], every
      coordinate being [N.zero]. *)
  let zero n = Now (T.init n (fun _ -> N.zero))

  (** [ones n] is an immediate vector built with [T.init n], every
      coordinate being [N.one]. *)
  let ones n = Now (T.init n (fun _ -> N.one))

  let to_code = Staged.to_code
  let to_expr = Staged.to_expr
  let of_expr = Staged.of_expr
  let of_immediate = Staged.of_immediate

  (** [coord v i] is the [i]-th coordinate of [v]: projected directly when
      [v] is immediate, otherwise projected in code and lifted with
      [lift_atom] or [lift_comp]. *)
  let coord v i =
    let unow t = T.proj_n t i
    and ulater t =
      let x = (T.proj_c t i) in
      (* NOTE(review): the atom/compound choice reads the flag of the whole
         tuple [t], not of the projected coordinate [x] — confirm this is
         intended. *)
      if t.a then lift_atom x.c else lift_comp x.c in
    mk_unary { unow = unow; ulater = ulater } v

  (** [of_coords l] builds a vector from a list of staged coordinates.
      If every coordinate is immediate the result is immediate; otherwise
      every coordinate is converted with [Staged.to_expr] and the result
      is [Later]. *)
  let of_coords l =
    if List.for_all Staged.is_now l
    then Now (T.of_list_n (List.map Staged.to_immediate l))
    else Later (T.of_list_c (List.map Staged.to_expr l))

  (** [map f v] applies [f] to every coordinate of [v]. *)
  let map f v =
    let unow t = T.map_n f t
    and ulater t = T.map_c f t in
    mk_unary { unow = unow; ulater = ulater } v

  (** [map2 f v v'] combines [v] and [v'] coordinate-wise with [f].
      Both branches pass the operands in the same order ([v] first), so a
      non-commutative [f] such as [N.sub_s] yields the same result whether
      the vectors are immediate or generated code.  (The code branch used
      to swap them, which made staged [sub] compute [v' - v].) *)
  let map2 f v v' =
    let bnow a b = T.map2_n f a b
    and blater a b = T.map2_c f a b in
    mk_binary { bnow = bnow; blater = blater } v v'

  (** [fold f z v] folds [f] over the coordinates of [v], starting at [z]. *)
  let fold f z v =
    let unow v = T.fold_n f z v
    and ulater v = T.fold_c f z v in
    mk_unary { unow = unow; ulater = ulater } v

  (** [mapfold m f z v] maps [m] over the coordinates of [v] and folds the
      results with [f] starting at [z] (delegated to [T.mapfold_*]). *)
  let mapfold m f z v =
    let unow v = T.mapfold_n m f z v
    and ulater v = T.mapfold_c m f z v in
    mk_unary { unow = unow; ulater = ulater } v

  (** [map2fold m f z v v'] combines [v] and [v'] coordinate-wise with [m]
      and folds the results with [f] starting at [z]. *)
  let map2fold m f z v v' =
    let bnow v v' = T.map2fold_n m f z v v'
    and blater v v' = T.map2fold_c m f z v v' in
    mk_binary { bnow = bnow; blater = blater } v v'

  (** Coordinate-wise equality: all [N.eq_s] results conjoined with
      [Bool.and_s], starting from [true]. *)
  let eq a b = map2fold N.eq_s Bool.and_s true a b
  let neq a b = Bool.not_s (eq a b)

  (** [mirror v] negates every coordinate of [v]. *)
  let mirror v = map N.neg_s v

  (* vector [+-] vector -> vector *)
  let add v0 v1 = map2 N.add_s v0 v1
  let sub v0 v1 = map2 N.sub_s v0 v1

  (* vector [./] scalar -> vector.  The scalar is bound once with
     [Code.let_] so it is not duplicated into every coordinate. *)
  let scale v s = Code.let_ s (fun s -> map (N.mul_s s) v)
  let shrink v s = Code.let_ s (fun s ->
    map (fun x -> N.div_s x s) v)

  (** [dot v0 v1] is [sum_i (v0_i * v1_i)].  When either operand equals the
      immediate all-ones vector, the product reduces to summing the other
      operand's coordinates, so the multiplications are skipped. *)
  let dot v0 v1 =
    (* NOTE(review): this uses polymorphic [=] on staged values; it relies
       on [Now]/[Later] values being structurally comparable against
       [ones T.dim] — confirm for all representations of [N.n]. *)
    if v0 = (ones T.dim) then fold N.add_s N.zero v1
    else if v1 = (ones T.dim) then fold N.add_s N.zero v0
    else map2fold N.mul_s N.add_s N.zero v0 v1

  (** Squared Euclidean length; [v] is let-bound so it is not duplicated. *)
  let length2 v = Code.let_ v (fun v -> (dot v v))

  (** Euclidean length, via [N.sqrt_s]. *)
  let length v = N.sqrt_s (length2 v)

  (** [direction v] is [v] divided by its length; [v] is let-bound first so
      it is shared between the numerator and the length computation. *)
  let direction v = Code.let_ v (fun v -> shrink v (length v))

  (** Generalized cross product — not implemented.
      @raise Failure always. *)
  let gcross _l = failwith "Vectoren.gcross: Not Impl Yet"

  (** [bcross v0 v1] is the 3D cross product of [v0] and [v1].
      A 2D variant is not provided: it would yield a scalar, not a vector.
      @raise Failure if [dim <> 3]. *)
  let bcross v0 v1 =
    if dim = 3 then begin
      let a0 = coord v0 0 and a1 = coord v0 1
      and a2 = coord v0 2
      and b0 = coord v1 0 and b1 = coord v1 1
      and b2 = coord v1 2 in
      let c0 = N.sub_s (N.mul_s a1 b2) (N.mul_s a2 b1)
      and c1 = N.sub_s (N.mul_s a2 b0) (N.mul_s a0 b2)
      and c2 = N.sub_s (N.mul_s a0 b1) (N.mul_s a1 b0) in
      of_coords [ c0; c1; c2 ]
    end
    else failwith "Vectoren.bcross: No impl yet for dim =/= 3"
end


