open Staged
open Algebra

(* @TODO separate the matrix container from the matrix operations
 *       so the container can change for sparse, dense, etc. *)

module type MATRIX = sig
  (** Type of a single scalar entry. *)
  type 'a n_s

  (** Type of a matrix whose entries are ['a n_s]. *)
  type 'a m

  (** {2 Shape queries} *)

  (** Number of rows. *)
  val nrows : 'a m -> int

  (** Number of columns. *)
  val ncols : 'a m -> int

  (** [dim m] is [(nrows m, ncols m)]. *)
  val dim : 'a m -> int * int

  (** [square m] holds when [m] has as many rows as columns. *)
  val square : 'a m -> bool

  (** [same_dim a b] holds when [a] and [b] have identical dimensions. *)
  val same_dim : 'a m -> 'a m -> bool

  (** {2 Element and slice access} *)

  (** [get m i j] is the entry at row [i], column [j]. *)
  val get : 'a m -> int -> int -> 'a n_s

  (** [row m i] is row [i] of [m] as a one-row matrix. *)
  val row : 'a m -> int -> 'a m

  (** [col m j] is column [j] of [m] as a one-column matrix. *)
  val col : 'a m -> int -> 'a m

  (** {2 Construction} *)

  (** [create r c f] is the [r] x [c] matrix whose [(i, j)] entry is
      [f i j]. *)
  val create : int -> int -> (int -> int -> 'a n_s) -> 'a m

  (** [create_row c f] is the 1 x [c] matrix whose [j]-th entry is [f j]. *)
  val create_row : int -> (int -> 'a n_s) -> 'a m

  (** [create_col r f] is the [r] x 1 matrix whose [i]-th entry is [f i]. *)
  val create_col : int -> (int -> 'a n_s) -> 'a m

  (** {2 Conversions from arrays} *)

  val row_of_array : 'a n_s array -> 'a m
  val col_of_array : 'a n_s array -> 'a m

  (** [of_array a] reads [a] row major: [a.(i)] becomes row [i]. *)
  val of_array : 'a n_s array array -> 'a m

  (** {2 Special matrices} *)

  val zero : int -> int -> 'a m
  val diag : int -> (int -> 'a n_s) -> 'a m
  val id : int -> 'a m

  (** {2 Structural transformations} *)

  val transpose : 'a m -> 'a m

  (** [haugment a b] places [b] to the right of [a]. *)
  val haugment : 'a m -> 'a m -> 'a m

  (** [vaugment a b] places [b] below [a]. *)
  val vaugment : 'a m -> 'a m -> 'a m

  (** [minor m i j] is [m] without row [i] and column [j]. *)
  val minor : 'a m -> int -> int -> 'a m

  (** {2 Entrywise iterators} *)

  val map : ('a n_s -> 'a n_s) -> 'a m -> 'a m
  val map2 : ('a n_s -> 'a n_s -> 'a n_s) -> 'a m -> 'a m -> 'a m

  (** {2 Matrix algebra} *)

  val add : 'a m -> 'a m -> 'a m
  val sub : 'a m -> 'a m -> 'a m
  val mul : 'a m -> 'a m -> 'a m

  (** {2 Scalar operations, applied to every entry} *)

  val sadd : 'a m -> 'a n_s -> 'a m
  val ssub : 'a m -> 'a n_s -> 'a m
  val smul : 'a m -> 'a n_s -> 'a m
  val sdiv : 'a m -> 'a n_s -> 'a m
end

(* Implementation *)

(* A generic three-field record together with a constructor and one
   projection per field. *)
type ('a, 'b, 'c) triple_type = { a : 'a; b : 'b; c : 'c }

(* [triple_create a b c] packs the three arguments, in order, into a record. *)
let triple_create a b c = { a; b; c }

(* Projections onto the first, second and third field respectively. *)
let triple_fst { a; _ } = a
let triple_snd { b; _ } = b
let triple_thr { c; _ } = c

(* Dangerously dynamic *)
module Matrix (N : FIELD) (* : MATRIX *) =
struct
  (** Scalar entries, as provided by the field [N]. *)
  type 'a n_s = 'a N.ns

  (** Row-major storage: an array of rows, each an array of entries. *)
  type 'a mat_rep = 'a n_s array array

  (** A matrix is the triple (row count, column count, storage). *)
  type 'a m = (int, int, 'a mat_rep) triple_type

  (* --- Internal helpers on the raw representation --- *)

  (* The underlying array of rows. *)
  let _matrix m = triple_thr m

  (* Row [i] of the storage; this is the stored array itself, not a copy. *)
  let _row m i = (_matrix m).(i)

  (* Column [i] of the storage, as a freshly allocated array. *)
  let _col m i = Array.map (fun r -> r.(i)) (_matrix m)

  (* --- Shape --- *)

  let nrows m = triple_fst m
  let ncols m = triple_snd m

  (** [dim m] is [(nrows m, ncols m)]. *)
  let dim m = (nrows m, ncols m)

  let square m = nrows m = ncols m

  let same_dim a b = nrows a = nrows b && ncols a = ncols b

  (* --- Construction --- *)

  (** [create r c f] builds the [r] x [c] matrix whose [(i, j)] entry is
      [f i j]. [f] is called row by row, left to right (the order in which
      [Array.init] applies its function). *)
  let create r c f =
    triple_create r c (Array.init r (fun i -> Array.init c (fun j -> f i j)))

  let create_row c f = create 1 c (fun _ j -> f j)
  let create_col r f = create r 1 (fun i _ -> f i)

  let row_of_array a = create_row (Array.length a) (fun i -> a.(i))
  let col_of_array a = create_col (Array.length a) (fun i -> a.(i))

  (** [of_array a] builds a matrix from a row-major array of rows.
      An empty [a] yields a 0 x 0 matrix.
      @raise Failure if the rows of [a] do not all have the same length. *)
  let of_array a =
    let r = Array.length a in
    let c = if r = 0 then 0 else Array.length a.(0) in
    (* Reject ragged input instead of truncating or indexing out of bounds. *)
    Array.iter
      (fun line ->
        if Array.length line <> c then failwith "Matrix.of_array: ragged rows")
      a;
    create r c (fun i j -> a.(i).(j))

  (* --- Access --- *)

  (** [get m i j] is the entry at row [i], column [j] (0-based). *)
  let get m i j = (_matrix m).(i).(j)

  (** [row m i] copies row [i] into a 1 x [ncols m] matrix. *)
  let row m i = create_row (ncols m) (fun j -> (_row m i).(j))

  (** [col m j] copies column [j] into a [nrows m] x 1 matrix. *)
  let col m j =
    (* Extract the column once instead of once per entry. *)
    let column = _col m j in
    create_col (nrows m) (fun i -> column.(i))

  (* TODO: structural equality needs an equality test on N's scalars; the
     earlier draft was:
     let eq a b = (same_dim a b) &&
       Array.iter (fun r -> Array.iter (fun x -> ) r) a *)

  (* --- Special matrices ---
     NOTE(review): [Now] presumably lifts an unstaged constant into
     ['a N.ns] (constructor from Staged) — confirm. *)

  let zero r c = create r c (fun _ _ -> Now N.zero)

  (** [diag n f] is the [n] x [n] matrix with [f i] at [(i, i)] and zero
      elsewhere. *)
  let diag n f = create n n (fun i j -> if i = j then f i else Now N.zero)

  let id n = diag n (fun _ -> Now N.one)

  (* --- Structural transformations --- *)

  let transpose m = create (ncols m) (nrows m) (fun i j -> get m j i)

  (** [haugment a b] is [a | b]: the columns of [b] appended to the right of
      those of [a].
      @raise Assert_failure if the row counts differ (when assertions are
      enabled). *)
  let haugment a b =
    let ra, ca = nrows a, ncols a
    and rb, cb = nrows b, ncols b in
    assert (ra = rb);
    let aug = Array.mapi (fun i x -> Array.append x (_row b i)) (_matrix a) in
    triple_create ra (ca + cb) aug

  (** [vaugment a b] stacks [b] below [a]. Rows are copied, so the result
      shares no storage with [a] or [b].
      @raise Assert_failure if the column counts differ (when assertions are
      enabled). *)
  let vaugment a b =
    let ca = ncols a and cb = ncols b in
    assert (ca = cb);
    (* Append row arrays directly rather than transposing three times. *)
    let rows = Array.append (_matrix a) (_matrix b) in
    triple_create (nrows a + nrows b) ca (Array.map Array.copy rows)

  (* --- Entrywise iterators --- *)

  let map f m = create (nrows m) (ncols m) (fun i j -> f (get m i j))

  (** [map2 f a b] combines corresponding entries of [a] and [b] with [f].
      @raise Failure if [a] and [b] have different dimensions. *)
  let map2 f a b =
    if not (same_dim a b) then failwith "Matrix.map2: wrong dimensions"
    else create (nrows a) (ncols a) (fun i j -> f (get a i j) (get b i j))

  (* --- Matrix algebra --- *)

  let add a b = map2 N.add_s a b
  let sub a b = map2 N.sub_s a b

  (* Dot product of two vectors of equal length. All products are formed
     first, then summed left to right starting from zero. *)
  let _vector_mul a b =
    let p = Array.init (Array.length a) (fun i -> N.mul_s a.(i) b.(i)) in
    Array.fold_left N.add_s (Now N.zero) p

  (** [mul a b] is the matrix product [a * b].
      @raise Failure if [ncols a <> nrows b]. *)
  let mul a b =
    if ncols a <> nrows b then failwith "Matrix.mul: wrong dimensions"
    else
      (* Extract every column of [b] once, not once per output entry. *)
      let cols = Array.init (ncols b) (_col b) in
      create (nrows a) (ncols b) (fun i j -> _vector_mul (_row a i) cols.(j))

  (* --- Scalar operations, applied to every entry of [m] ---
     [sub] computes [N.sub_s a_ij b_ij] for a - b, so the entry must be the
     first operand of [N.sub_s] / [N.div_s] to get (entry - s) and
     (entry / s). *)

  let sadd m s = map (fun x -> N.add_s s x) m
  let ssub m s = map (fun x -> N.sub_s x s) m
  let smul m s = map (fun x -> N.mul_s s x) m
  let sdiv m s = map (fun x -> N.div_s x s) m

  (** [minor m i j] is a copy of [m] without row [i] and column [j]. *)
  let minor m i j =
    (* Map an index of the minor back to the original, stepping over the
       removed row/column. *)
    let skip x th = if x < th then x else x + 1 in
    let f x y = get m (skip x i) (skip y j) in
    create (nrows m - 1) (ncols m - 1) f
end
