open Staged
open Basetypes
open Util
open Vertex

(* Interface of a simplex over staged vertices [V.vertex_s].
   A d-simplex has (d+1) vertices.  Predicates return ['a Bool.b] and
   index lookups return [('a, int) staged]; both are staged values from
   Staged/Basetypes (NOTE(review): presumably either known immediately
   or only in generated code - confirm there). *)
module type SIMPLEX =
sig
  module V : VERTEX
  (* Abstract simplex representation. *)
  type 'a simplex_s
  (* Dimension d of a d-simplex, i.e. its vertex count minus one. *)
  val dim : 'a simplex_s -> int
  (* Build a simplex from its vertices; presumably the list order fixes
     the indices used by [vertex], [face] and [oface]. *)
  val of_vertices : 'a V.vertex_s list -> 'a simplex_s
  (* isomorphism modulo order of vertices: staged test that both
     simplices have the same vertices, ignoring their order *)
  val eq : 'a simplex_s -> 'a simplex_s -> 'a Bool.b
  (* vertices *)
  (* [vertex s i]: the i-th vertex (0-based in N_Simplex). *)
  val vertex : 'a simplex_s -> int -> 'a V.vertex_s
  (* All vertices of the simplex. *)
  val vertices : 'a simplex_s -> 'a V.vertex_s list
  (* Staged test: is the vertex one of the simplex's vertices? *)
  val is_vertex : 'a simplex_s -> 'a V.vertex_s -> 'a Bool.b
  (* Staged index of the vertex within the simplex. *)
  val v_index : 'a simplex_s -> 'a V.vertex_s -> ('a, int) staged
  (* faces *)
  (* non-oriented face: nD.  In N_Simplex, [face s i] is the
     (d-1)-simplex opposite vertex i. *)
  val face : 'a simplex_s -> int -> 'a simplex_s
  (* Oriented face: 2D & 3D only. Such that for a vertex v_i
     and opposite face f_i: orient f_i v_i = +ve (positive) *)
  val oface : 'a simplex_s -> int -> 'a simplex_s
  (* All non-oriented / oriented faces, presumably in index order. *)
  val faces : 'a simplex_s -> 'a simplex_s list
  val ofaces : 'a simplex_s -> 'a simplex_s list
  (* Staged test / staged index of a face among the simplex's faces. *)
  val is_face : 'a simplex_s -> 'a simplex_s -> 'a Bool.b
  val f_index : 'a simplex_s -> 'a simplex_s -> ('a, int) staged
  (* neighbors *)
  (* In N_Simplex, [is_neighbor s n] holds when some face of [s] is also
     a face of [n], and [n_index s n] is the index in [faces s] of such a
     shared face. *)

  val is_neighbor : 'a simplex_s -> 'a simplex_s -> 'a Bool.b
  val n_index : 'a simplex_s -> 'a simplex_s -> ('a, int) staged
  (* Not provided yet:
  val neighbor : 'a simplex_s -> int -> 'a simplex_s
  val neighbors : 'a simplex_s -> 'a simplex_s list *)
end

(* util: the index array [| 0; 1; ...; n-1 |]; raises [Invalid_argument]
   for negative [n], as [Array.init] does. *)
let array_0n n =
  let identity idx = idx in
  Array.init n identity

(* util: the list [ f 0; f 1; ...; f (n-1) ].  [f] is applied in
   increasing index order (that is [Array.init]'s documented order). *)
let list_0n n f =
  let results = Array.init n f in
  Array.fold_right (fun x acc -> x :: acc) results []

(* util: remove the element at position [nth] (0-based) from list [a].
   Only that position is dropped: other elements equal to it are kept
   (a value-based filter would remove every copy, and would need
   structural equality on the elements).
   Raises [Invalid_argument] if [nth] is negative and [Failure] if
   [nth] >= [List.length a] - the same exception kinds as [List.nth]. *)
let remove_l a nth =
  if nth < 0 then invalid_arg "remove_l";
  (* tail-recursive: [kept] holds the already-visited prefix, reversed *)
  let rec drop i kept = function
    | [] -> failwith "remove_l"
    | x :: rest ->
      if i = nth then List.rev_append kept rest
      else drop (i + 1) (x :: kept) rest
  in
  drop 0 [] a

(* util: array version of [remove_l]: a fresh array without index [nth];
   raises as [remove_l] does. *)
let remove_a a nth = Array.of_list (remove_l (Array.to_list a) nth)

module N_Simplex
  (V : VERTEX)
  (* : SIMPLEX *) =
struct
  module V = V

  (* A simplex is represented by its vertex list; the list order fixes
     the vertex indices.  A d-simplex holds d+1 vertices. *)
  type 'a simplex_s = 'a V.vertex_s list

  (* Dimension d = number of vertices - 1 (the empty list gives -1). *)
  let dim s = List.length s - 1

  (* No validation: the given list is the simplex. *)
  let of_vertices l = l

  (* [vertex s i] is the i-th vertex (0-based); out-of-range indices
     raise as [List.nth] does. *)
  let vertex s i = List.nth s i
  let vertices s = s

  (* Staged membership test and staged index lookup, comparing vertices
     with [V.eq]. *)
  let is_vertex s v = exists (fun x -> V.eq v x) s
  let v_index s v = findi (fun x -> V.eq v x) s

  (* Isomorphism modulo vertex order.  The dimension check is decided
     immediately; otherwise the result is the staged conjunction of
     [is_vertex a v] over every vertex [v] of [b].
     NOTE(review): with repeated vertices (degenerate simplices) this can
     report equality for simplices with different vertex sets. *)
  let eq a b =
    if dim a <> dim b then Now false
    else
      let memberships = List.map (fun v -> is_vertex a v) b in
      List.fold_left (fun acc m -> Bool.and_s acc m) (Now true) memberships

  (* Oriented-face table: row k lists, in orientation order, the indices
     of the vertices of the face opposite vertex k.  Only triangles
     (dim 2) and tetrahedra (dim 3) are supported; any other dimension
     raises [Failure].  The second argument is ignored (the whole table
     is returned) and is kept for interface compatibility. *)
  let oface_desc s _i =
    match dim s with
    | 2 -> [| [| 1; 2 |]; [| 2; 0 |]; [| 0; 1 |] |]
    | 3 ->
      [| [| 2; 1; 3 |]; [| 0; 2; 3 |];
         [| 1; 0; 3 |]; [| 0; 1; 2 |] |]
    | _ ->
      (* the old message blamed dimensions above 3, even for lower ones *)
      failwith "oriented faces are only implemented for dimensions 2 and 3"

  (* Unoriented-face table for any dimension: with n = dim s + 1, an
     n by (n-1) matrix whose row k is [0 .. n-1] without k, i.e. the
     vertex indices of the face opposite vertex k.  The second argument
     is ignored and kept for interface compatibility. *)
  let face_desc s _i =
    let all = array_0n (dim s + 1) in
    Array.map (fun k -> remove_a all k) all

  (* [_face s i m] builds face [i] of [s] by looking up, in [s], every
     vertex index listed in row [i] of the descriptor table [m]. *)
  let _face s i m =
    let row = m.(i) in
    let verts = Array.map (fun j -> vertex s j) row in
    of_vertices (Array.to_list verts)

  (* Oriented face i (2D and 3D only). *)
  let oface s i = _face s i (oface_desc s i)
  (* Unoriented face i, opposite vertex i (any dimension). *)
  let face s i = _face s i (face_desc s i)

  (* All unoriented faces, face k opposite vertex k.  The descriptor table
     is built once instead of once per face. *)
  let faces s =
    let m = face_desc s 0 in
    list_0n (dim s + 1) (fun i -> _face s i m)

  (* All oriented faces.  The table is looked up per face so that an
     empty simplex still yields [] instead of raising. *)
  let ofaces s = list_0n (dim s + 1) (fun i -> oface s i)

  (* Staged test / index lookup of [f] among the faces of [s], compared
     with [eq] (i.e. modulo vertex order). *)
  let is_face s f = exists (fun x -> eq x f) (faces s)
  let f_index s f = findi (fun x -> eq x f) (faces s)

  (* [s] and [n] are neighbors when some face of [s] is also a face of
     [n]; [n_index] gives the index in [faces s] of such a shared face
     (presumably the first one - see [findi]).
     NOTE(review): nothing here excludes [n] = [s]; confirm callers do
     not rely on either behavior. *)
  let is_neighbor s n = exists (fun f -> is_face n f) (faces s)
  let n_index s n = findi (fun f -> is_face n f) (faces s)
  (* old implementation (superseded by the face-sharing test above)
  let is_neighbor s n =
    if (dim s) <> (dim n) then Now false
    else
      let fs = faces s and fn = faces n in
      let l = List.map (fun f -> is_face s f) (faces n) in
      List.fold_left Bool.or_s (Now false) l *)
end


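(* Unfinished drafts of fixed-arity simplex modules (Segment, Triangle),
   kept for reference only; they do not compile as written
   (for example, [neigbor] in Segment has no body).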
module Segment
  (V : VERTEX)
  : SIMPLEX =
struct
  module V = V
  type 'a simplex_s = 'a V.vertex_s * 'a V.vertex_s
  let dim _ = 2
  let of_vertices l = (List.nth l 0), (List.nth l 1)
  let vertex (a,b) i = function | 0 -> a | 1 -> b
                                | _ -> "failwith Segment index error"
  let vertices (a,b) = [a; b]
  let is_vertex (a,b) v = Bool.or_s (V.eq v a) (V.eq v b)
  let v_index (a,b) v = Code.ife_ (V.eq v a) (Now 0) (Now 1)

  let face i s = failwith "error Segment.face"
  let faces s = failwith "error Segment.faces"
  let is_face s f = failwith "error Segment.is_face"
  let f_index s f = failwith "error Segment.f_index"

  let neigbor s i = 
end

module Triangle
  (N : RING)
  (P : POINT with type number = N.n and type 'a number_s = 'a N.ns)
   (* : SIMPLEX *) =
struct
  type 'a number_s = 'a N.ns
  type 'a vertex_s = 'a P.point_s
  type 'a simplex_s = 'a vertex_s * 'a vertex_s * 'a vertex_s
  let dim = 3
  let from_vertices l = let a = Array.of_list l in
                    a.(0), a.(1), a.(2)
  let vertices (a,b,c) = [a; b; c]
  let face s i =
    let a = Array.of_list (vertices s) in
    List.map (fun i -> a.(i)) [| [1; 2]; [2; 0]; [0; 1] |].(i)
  let faces s = List.map (fun i -> face s i) [0; 1; 2]
end
*)

