\section{Wrap}

Wrap occurrences of some type in the types of a module's exported functions
into a data constructor --- useful for restricted wrappers.

%{{{ imports
\begin{code}
{-# LANGUAGE ScopedTypeVariables, NoMonomorphismRestriction #-}
-- module Wrap where

-- import Language.Haskell.Exts.Syntax

import GetModuleExports (moduleExports, Exports(..))
import MatchTypes (matchType, stripSigma, substTypes
                  , TyPos(..), tyVarPoss, matchesSubType)
import PprUtils (splitQual, pprDataCon, pprintE, pprDataConImport, pprHsVarNewQual)
import FileUtils (writeModuleSDoc)

-- package GHC
import GHC ( Id, idType, DataCon, dataConType, TyVar )
import Type  ( Type )
import TcType  ( tcSplitSigmaTy, tcSplitFunTys
               , tcSplitFunTy_maybe, tcSplitTyConApp_maybe)
import TyCon  (TyCon, tyConTyVars, tyConDataCons_maybe)
import DataCon (DataCon, dataConFullSig)
import HscTypes ()
import Module   ( ModuleName, mkModuleName, moduleNameString )
import Outputable

import System (getArgs)
import System.IO (hPutStrLn, stderr)

import Data.Maybe (mapMaybe, isJust)

-- import Data.Set (Set)
import qualified Data.Set as Set
import Data.Map (Map)
import qualified Data.Map as Map

import Control.Monad (liftM, MonadPlus(mplus, mzero), msum)
import Control.Arrow ( first, (***) )
import Control.Monad.State  (State, get, put, modify, evalState, runState
                            , StateT, runStateT)
import Control.Monad.Trans (lift)
\end{code}
%}}}

\begin{code}
main :: IO ()
main = do
  clArgs0 <- getArgs
  -- Options are recognised only in this order: "-w" first, then "-o TARGET".
  -- "-w" makes the generated module enable -fwarn-missing-signatures;
  -- "-o TARGET" overrides the name of the generated module.
  let (clArgs1, warn) = case clArgs0 of
        "-w" : ss -> (ss, True)
        ss        -> (ss, False)
  let (clArgs, target) = case clArgs1 of
        "-o" : trg : ss  -> (ss, Just trg)
        ss               -> (ss, Nothing)
  case clArgs of
    -- first the implementation module, then one or more qualified wrapper
    -- data constructors (each split into module and constructor by |splitQual|)
    implModString : wrapperStrings@(_ : _) ->
      wrapImpl (mkModuleName implModString) (map splitQual wrapperStrings) target warn
    _ -> hPutStrLn stderr "usage: Wrap [-w] [-o TargetModule] ImplementationModule WrapModule.WrapperDataCon ..."
\end{code}

%{{{ findWrapper
\begin{code}
findWrapper :: (ModuleName, String) -> Exports -> Either SDoc ((ModuleName, String), (DataCon, Type))
findWrapper p@(wrapModName, wrapperString) wrapExports =
    maybe notFound checkWrapper (lookup wrapperString candidates)
  where
    dataCons = Set.toList (exportDataCons wrapExports)
    -- Every exported data constructor is offered under three renderings
    -- (|showSDoc|, |showSDocUnqual| and |show|, in that order);
    -- |lookup| takes the first one matching the requested name.
    candidates = concatMap (\ render -> [ (render dc, dc) | dc <- dataCons ])
                           [showSDoc . ppr, showSDocUnqual . ppr, show]
    -- Error listing all names that would have been accepted.
    notFound = Left . text $ "Wrap: " ++ wrapperString
                 ++ " not found in " ++ moduleNameString wrapModName
                 ++ " DataCon exports:\n"
                 ++ unlines (map fst candidates)
    -- The wrapped type is the first argument type of the constructor.
    checkWrapper wrapperDC =
        case tcSplitFunTy_maybe wTy of
          Nothing              -> Left $ ppr wrapperDC <+> text " not a function!"
          Just (argTy, _resTy) -> Right (p, (wrapperDC, argTy))
      where
        (_wForalls, _wTheta, wTy) = tcSplitSigmaTy (dataConType wrapperDC)
\end{code}
%}}}

%{{{ collectEithers
\begin{code}
collectEithers :: [Either SDoc a] -> ([SDoc], [a])
-- Split a list of results into the error documents and the successes,
-- each in their original order.
collectEithers = foldr step ([], [])
  where
    -- strict pair match: the whole list is traversed before anything is returned
    step (Left pr) (prs, xs) = (pr : prs, xs)
    step (Right x) (prs, xs) = (prs, x : xs)
\end{code}
%}}}

%{{{ reportWrapper
\begin{code}
reportWrapper :: ((ModuleName, String), (DataCon, Type)) -> IO ()
-- Print the resolved wrapper constructor and the type it wraps to stderr.
reportWrapper (_, (wrapperDC, wrappedTy)) = mapM_ pprintE
  [ text "Wrapper found:" <+> pprDataCon wrapperDC
  , text "Wrapped type: " <+> ppr wrappedTy
  ]
\end{code}
%}}}

%{{{ wrapImpl :: ModuleName -> [(ModuleName, String)] -> Maybe String -> Bool -> IO ()
\begin{code}
wrapImpl :: ModuleName -> [(ModuleName, String)] -> Maybe String -> Bool -> IO ()
wrapImpl implModName wrapperPs target warn = do
  -- Exports of the implementation module first, then those of each
  -- wrapper module, in the order of |wrapperPs|.
  implExports : wrapExportss  <- moduleExports $ implModName : map fst wrapperPs
  let (noWrapperPprs, wrappers) = collectEithers $ zipWith findWrapper wrapperPs wrapExportss
  if null wrappers
    then fail . showSDoc $ vcat noWrapperPprs
    else do
      -- Unresolved wrappers are reported but are not fatal,
      -- as long as at least one wrapper was found.
      if null noWrapperPprs then return ()
        else do
          hPutStrLn stderr "================================"
          pprintE $ vcat noWrapperPprs
      hPutStrLn stderr "================================"
      mapM_ reportWrapper wrappers
      hPutStrLn stderr "================================"

      let implIdents = Set.toList $ exportIds implExports

      let wrapImport ((wmod, _),(wdc,_)) = pprDataConImport wmod wdc

      -- Default target: the implementation module name, a dot, and the
      -- concatenation of the requested wrapper constructor names.
      let newModName = case target of
            Just trg -> trg
            Nothing -> moduleNameString implModName ++ '.' : (wrappers >>= snd . fst)

      let (idents, defs) = unzip $ mapMaybe (doIdent $ map snd wrappers) implIdents
          -- qualified export identifiers necessary to avoid |Prelude| name clashes
          exportItems = zipWith (\ c ident -> text ("  " ++ [c]) <+> pprHsVarNewQual newModName ident)
                                ('(' : repeat ',') idents
          -- The opening parenthesis is attached to the first export item.
          -- With no exports it has to be emitted on its own; otherwise the
          -- generated header would read "module M ) where".
          exportList
            | null idents = [text "  (", text "  ) where"]
            | otherwise   = exportItems ++ [text "  ) where"]
      writeModuleSDoc "Wrap" ".hs" newModName . vcat
        $  (if warn
              then (:) (text "{-# OPTIONS_GHC -fwarn-missing-signatures #-}")
              else id
           )
        $  (:) (text "{-# LANGUAGE NoMonomorphismRestriction #-}")
        $  (:) (text "module" <+> text newModName)
        $  (++) exportList
        $  (:) (text "")
        $  (:) (text "import qualified " <+> ppr implModName)
        $  (++) (map wrapImport wrappers)
        $  (:) (text "")
        $  defs
      return ()
\end{code}
%}}}

%{{{ doIdent :: [(DataCon, Type)] -> Id -> Maybe (Id, SDoc)
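|doIdent| returns |Nothing| if neither an argument type nor the result type
of the identifier involves a wrapped type; otherwise it returns a pair of: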
\begin{itemize}
\item The identifier of the wrapped function,
  which is currently the original identifier,
  to be used for the export list.

\item An |SDoc| containing the definition of the wrapped function.
\end{itemize}
\begin{code}
doIdent :: [(DataCon, Type)] -> Id -> Maybe (Id, SDoc)
-- |wrappers| pairs each wrapper data constructor with the type it wraps.
doIdent wrappers ident = let
    -- argument types and result type of |ident|, after |stripSigma|
    (argTys, resTy) = tcSplitFunTys $ stripSigma $ idType ident
    -- right-hand sides call the original function, rendered in the default
    -- user style; left-hand sides use the plain |pprHsVar| rendering
    rhsFunName = withPprStyle defaultUserStyle (pprHsVar ident)
    lhsFunName = pprHsVar ident

    -- wrapper constructors rendered as |SDoc|, paired with their wrapped types
    wrappers' = map (first pprHsVar) wrappers

    -- |inArgs| is |True| if some argument type involves a wrapper;
    -- each element of |argPs| yields, for one equation, the argument
    -- patterns with wrapper constructors and the corresponding raw patterns
    argPs ::  [State [SDoc] ([HSPat SDoc SDoc], [HSPat SDoc SDoc])]
    (inArgs, argPs) = wrapArgTys [] wrappers' argTys

    -- |Nothing| if no wrapper applies to the result type
    resWPs = wrapDCs [] wrappers' resTy
    precPprHSP = precPprHSPat id id

    -- one equation "aux raw = withDC" of the local result-wrapping function;
    -- the wrapped side is printed as an expression, so lazy markers are removed
    mkAuxEq :: SDoc  -> (HSPat SDoc SDoc, HSPat SDoc SDoc) -> SDoc
    mkAuxEq aux (withDC, raw) = 
        precPprApplys precPprHSP 0 aux [raw] <+> text "=" <+> precPprHSP 0 (unIrref withDC)
    -- right-hand side built around the call printer |precPprInner|:
    -- unchanged if the result needs no wrapping; a direct constructor
    -- application if there is exactly one trivial wrapping alternative;
    -- otherwise a local "wrap" function defined in a let
    pprRes' precPprInner = case resWPs of
      Nothing -> return $ precPprInner 0
      Just mps -> let
        aux = text "wrap" -- safe because of explicit and qualified imports.
        in do
          -- every result alternative draws variable names from the same
          -- remaining supply
          auxArgs <- get
          auxPairs <- mapM (\ mp -> put auxArgs >> mp) mps
          return $ case auxPairs of
            [(DCApp dc [PVar _v], PVar _v')] -> dc <+> precPprInner 9
            _  ->   text " let" $$ nest 6 (vcat $ map (mkAuxEq aux) auxPairs)
               $+$  nest 2 (text "in " <+> aux <+> precPprInner 9)
    -- one equation: wrapper patterns on the left; on the right, the original
    -- function applied to the raw patterns (printed as expressions)
    runArgP mp = do
      (withDCs, raws) <- mp
      let precPprInner p = precPprApplys precPprHSP p rhsFunName $ map unIrref raws
      rhs <- pprRes' precPprInner
      return $ precPprApplys precPprHSP 0 lhsFunName (map unIrref withDCs) <+> text "=" <+> rhs
  -- The definition is a blank line, a block comment showing the original
  -- type, then one equation per argument alternative; each equation takes
  -- its variable names afresh from |args|.
  in if inArgs || isJust resWPs
     then  Just  .    (,) ident
                 $    text ""
                 $+$  text "{-" <+> rhsFunName <+> text "::" <+> ppr (idType ident) <+> text "-}"
                 $+$  vcat (map (flip evalState args . runArgP) argPs)
     else  Nothing
\end{code}
%}}}

%{{{ \subsection{|HSExpr| and |HSPat|}
\subsection{|HSExpr| and |HSPat|}

%{{{ data HSExpr
\begin{code}
-- A small Haskell expression syntax used when printing generated code;
-- |dc| is the type of data constructors, |v| the type of variables.
data HSExpr dc v
  = Var v
  | DC dc
  | Lambda [HSPat dc v] (HSExpr dc v)   -- lambda abstraction: patterns, body
  | Apply (HSExpr dc v) (HSExpr dc v)   -- function applied to one argument
\end{code}
%}}}

%{{{ precPprHSExpr
\begin{code}
precPprHSExpr :: (dc -> SDoc) -> (v -> SDoc) -> Int -> HSExpr dc v -> SDoc
-- Precedence-aware printer, following the convention of |precPprApplys|:
-- an application is parenthesised above precedence 8, its function part
-- is printed at 8 (application is left-associative), and its argument at 9.
precPprHSExpr _pprDC  pprVar _p  (Var v) = pprVar v
precPprHSExpr pprDC  _pprVar _p  (DC dc) = pprDC dc
precPprHSExpr pprDC   pprVar p  (Lambda pats e) = cparen (p > 1) $
  text "\\" <+> hsep (map (precPprHSPat pprDC pprVar 9) pats) <+> text "->" <+> precPprHSExpr pprDC pprVar 0 e
precPprHSExpr pprDC   pprVar p   (Apply f a) = cparen (p > 8) $
  -- at 8, a nested application stays bare while a lambda gets parentheses
  precPprHSExpr pprDC pprVar 8 f <+> precPprHSExpr pprDC pprVar 9 a
\end{code}
%}}}

%{{{ data HSPat, unIrref
\begin{code}
-- Patterns over data constructors |dc| and variables |v|.
data HSPat dc v
  = PVar v                 -- variable pattern
  | Irref (HSPat dc v)     -- lazy (irrefutable) pattern, printed with a tilde
  | DCApp dc [HSPat dc v]  -- constructor applied to argument patterns

unIrref :: HSPat dc v -> HSPat dc v
-- Remove every lazy-pattern marker, at any depth of the pattern.
unIrref pat = case pat of
  Irref inner      -> unIrref inner
  DCApp dc subPats -> DCApp dc (map unIrref subPats)
  PVar _           -> pat
\end{code}
%}}}

%{{{ exprFromHSPat
\begin{code}
exprFromHSPat :: HSPat dc v -> HSExpr dc v
-- Read a pattern as the expression that constructs the matched value;
-- lazy-pattern markers have no expression counterpart and are dropped.
exprFromHSPat (PVar v) = Var v
exprFromHSPat (Irref p) = exprFromHSPat p
-- left fold, so that |DCApp c [p1, p2]| becomes |Apply (Apply (DC c) e1) e2|,
-- i.e. the constructor applied to its arguments in order
exprFromHSPat (DCApp dc ps) = foldl (\ f p -> Apply f (exprFromHSPat p)) (DC dc) ps
\end{code}
%}}}

%{{{ precPprHSPat, precPprApplys
\begin{code}
precPprHSPat :: (dc -> SDoc) -> (v -> SDoc) -> Int -> HSPat dc v -> SDoc
-- Precedence-aware pattern printer; the printing functions for
-- constructors and variables stay fixed through the recursion.
precPprHSPat pprDC pprVar = go
  where
    go _ (PVar v)        = pprVar v
    go _ (Irref pat)     = text "~" <> go 9 pat
    go p (DCApp dc pats) = precPprApplys go p (pprDC dc) pats

precPprApplys ::  (Int -> a -> SDoc) -> Int -> SDoc -> [a] -> SDoc
-- Print |f| applied to |pats|, each argument at precedence 9; the whole
-- application is parenthesised when the context precedence exceeds 8.
-- Without arguments, |f| is printed bare.
precPprApplys pprX p f pats
  | null pats = f
  | otherwise = cparen (p > 8) . hsep $ f : [ pprX 9 pat | pat <- pats ]
\end{code}
%}}}
%}}}

%{{{ \subsection{|wrapDCs|}
\subsection{|wrapDCs|}

\edcomm{WK}{The caller that inserted the failing |stoptype| should fail, too!}

\begin{code}
wrapDCs :: [Type] -> [(SDoc, Type)] -> Type -> Maybe [State [a] (HSPat SDoc a, HSPat SDoc a)]
-- Give up on types matching a stop type; otherwise try a direct match
-- against each wrapper first, then a match inside the constructors of |ty|.
wrapDCs stopTypes wps ty
  | any (isJust . (`matchType` ty)) stopTypes = Nothing
  | otherwise = msum (map (`matchDC` ty) wps) `mplus` wrapPattern' stopTypes wps ty
\end{code}

%{{{ matchDC
|matchDC| only tries a top-level match of the argument wrapper.
\begin{code}
matchDC :: (SDoc, Type) -> Type -> Maybe [State [a] (HSPat SDoc a, HSPat SDoc a)]
-- On a match, the single alternative takes one fresh variable and pairs
-- the wrapped pattern |wdc a| with the bare variable |a|.
matchDC (wdc, wTyL) ty = fmap (const [wrapOne]) (matchType wTyL ty)
  where
    wrapOne = nextArg >>= \ a -> return (DCApp wdc [PVar a], PVar a)
\end{code}
%}}}

%{{{ wrapPattern'
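|wrapPattern'| is a makeshift wrapper for |wrapPattern|
until the control flow is worked out properly:
it returns the constructor alternatives only if a wrapper match
was found in at least one constructor argument.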

\begin{code}
wrapPattern' :: [Type] -> [(SDoc, Type)] -> Type -> Maybe [State [a] (HSPat SDoc a, HSPat SDoc a)]
-- Keep the alternatives only when some constructor argument was wrapped.
wrapPattern' stopTypes wps ty =
  case wrapPattern stopTypes wps ty of
    Just (True, dcAlts) -> Just dcAlts
    _                   -> Nothing
\end{code}
%}}}

%{{{ wrapPattern
|wrapPattern| returns |Nothing| if 
|ty| is not a constructor type, or we cannot get at the constructors.
Otherwise, the Boolean is |True|
if a wrapper match is found in at least one constructor argument.
\edcomm{WK}{Returning |noWrap| is premature for the case of phantom types.
Should look into |argTys| for that.}

\begin{code}
wrapPattern :: [Type] -> [(SDoc, Type)] -> Type -> Maybe (Bool, [State [a] (HSPat SDoc a, HSPat SDoc a)])
wrapPattern stopTypes wps ty = do
  (tc, _argTys) <- tcSplitTyConApp_maybe ty
  dcs <- tyConDataCons_maybe tc
  let -- lazy patterns only when there is at most one constructor
      irref  = null (drop 1 dcs)
      perDC  = map (wrapDataCon irref stopTypes wps ty) dcs
      inDCs  = any fst perDC
      dcAlts = map snd perDC
      -- all constructors must contribute alternatives; otherwise fall back
      -- to the plain type-shift alternative
      alts | inDCs && all (not . null) dcAlts = concat dcAlts
           | otherwise                        = noWrap
  return (inDCs, alts)
\end{code}
%}}}

%{{{ wrapDataCon
|wrapDataCon| returns |True| if a wrapper match is found in the arguments of |dc|.
If |False|, the second result is empty when |dc| is recursive or cannot
construct the current type, and otherwise just a type-shift pattern pair.

\edcomm{WK}{The |matchesSubType| should catch
at least non-polymorphic recursions,
so that |stoptypes| should now be superfluous.}

\begin{code}
wrapDataCon :: Bool -> [Type] -> [(SDoc, Type)] -> Type -> DataCon -> (Bool, [State [a] (HSPat SDoc a, HSPat SDoc a)])
wrapDataCon irref stopTypes wps ty dc
  | any (matchesSubType resTy) argTys = (False, [])   -- recursive constructor
  -- a failed match means |dc| is not a constructor for the current type
  -- (e.g. in a GADT)
  | otherwise = maybe (False, []) wrapWith (matchType resTy ty)
  where
    (_univTyVars, _exTyVars, _eqSpec, _theta1, _theta2, argTys, resTy) = dataConFullSig dc
    -- apply |dc| to argument patterns, lazily if requested
    dcApp = (if irref then Irref else id) . DCApp (pprHsVar dc)
    -- instantiate the argument types and wrap each alternative's argument
    -- patterns in |dc|, on both the wrapped and the raw side
    wrapWith subst =
      let (inArgTys, apss) = wrapArgTys (ty : stopTypes) wps (substTypes subst argTys)
      in  (inArgTys, map (liftM (dcApp *** dcApp)) apss)
\end{code}
%}}}

%{{{ wrapArgTys
|wrapArgTys| returns |True| if a wrapper match is found in |argTys|.
If |False|, the second result is just a type-shift pattern pair,
constructed using |noWrap|.
\begin{code}
wrapArgTys :: [Type] -> [(SDoc, Type)] -> [Type] -> (Bool, [State [a] ([HSPat SDoc a], [HSPat SDoc a])])
-- Every combination of per-argument alternatives (cartesian product)
-- becomes one alternative for the whole argument list.
wrapArgTys stopTypes wps argTys =
    (anyWrapped, [ liftM unzip (sequence combo) | combo <- sequence altsPerArg ])
  where
    perArg     = [ wrapDCs stopTypes wps argTy | argTy <- argTys ]
    anyWrapped = any isJust perArg
    -- unwrappable arguments contribute the plain type-shift alternative
    altsPerArg = [ maybe noWrap id m | m <- perArg ]
\end{code}
%}}}

\begin{code}
noWrap :: [State [a] (HSPat dc a, HSPat dc a)]
-- The single "no wrapper" alternative: the same fresh variable on both sides.
noWrap =
  [ do var <- nextArg
       return (PVar var, PVar var)
  ]
\end{code}

\begin{code}
nextArg :: State [a] a
-- Take the next variable name from the supply in the state.
-- The case forces the supply's next cell; for |args| this raises its own
-- "more than 26 arguments!" error once exhausted.
nextArg = do
  supply <- get
  case supply of
    a : as -> do
      put as
      return a
    -- explicit, instead of a failable pattern binding that would depend on
    -- |fail| (MonadFail) of the underlying monad
    [] -> error "nextArg: argument supply exhausted"
\end{code}

\begin{code}
args :: [SDoc]
-- Single-letter variable names; demanding a 27th one is an error.
args = [ text [c] | c <- ['a' .. 'z'] ] ++ error "more than 26 arguments!"
\end{code}
%}}}

%{{{ \subsection{|mkFunctorEnv|}
\subsection{|mkFunctorEnv|}

%{{{ data FunctorSpec, type FunctorEnv
\begin{code}
-- A type constructor together with the index of the type parameter
-- that an |fmap| function should map over.
data FunctorSpec = FunctorSpec TyCon Int
  deriving (Eq, Ord)

-- Rendered as the constructor name, the type constructor and the index.
instance Outputable FunctorSpec where
  ppr (FunctorSpec tyCon k) = hsep [text "FunctorSpec", ppr tyCon, int k]

-- |Show| delegates to the |Outputable| rendering.
instance Show FunctorSpec where
  showsPrec p fs = showsPrecSDoc p (ppr fs)
\end{code}

|FunctorEnv| is a cache remembering, for each |FunctorSpec| tried so far,
its result (|Nothing| on failure), which in the case of success consists of
\begin{itemize}
\item the specific |fmap| function name as a |String|
\item the definitions for that |fmap| function as an |SDoc|
\end{itemize}

\begin{code}
-- For each |FunctorSpec| tried: |Nothing| on failure, otherwise its |FMapImpl|.
type FunctorEnv = Map FunctorSpec (Maybe FMapImpl)

-- The name of a generated |fmap| function and its definition.
type FMapImpl = (String, SDoc)
\end{code}
%}}}

%{{{ mkFMap
|mkFMap| returns, if successful, the name of the generated |fmap| function;
alongside, its |StateT| state accumulates a set of |FunctorSpec| keys into the |FunctorEnv| state.
This set is implemented as a |Map| to facilitate |Map.intersection| with the final state.

\begin{code}
-- Generate an |fmap|-like function for mapping over type parameter |k|
-- of type constructor |tc|, recording the outcome in the |FunctorEnv| cache.
-- Returns |Nothing| if |k| is too large, if the constructors of |tc| are
-- not available, or if some constructor alternative cannot be handled.
mkFMap :: FunctorSpec
  ->   StateT (Map FunctorSpec ())
      (State FunctorEnv)
      (Maybe String)
mkFMap fs@(FunctorSpec tc k) = let
    tctvs = tyConTyVars tc
  -- NOTE(review): a negative |k| is not rejected here and would make |!!| fail.
  in if k >= length tctvs then return Nothing
  else let
    -- the type variable being mapped over
    v = tctvs !! k
  in case tyConDataCons_maybe tc of
    Nothing -> return Nothing
    Just dcs -> let
        -- lazy patterns only when |tc| has at most one constructor
        irref = case dcs of _ : _ : _ -> False; _ -> True
        fmapIdent = "fmap_" ++ showSDoc (ppr tc) ++ '_' : show k
        fmapFVar = text "ff"
        -- Placeholder definition while |fs| is being generated: cache lookups
        -- (see |getFMapPos|) only use the name, so recursive occurrences
        -- find it; forcing the definition itself raises this error.
        blackhole = error $ fmapIdent ++ ": in statu nascendi"
        -- NOTE(review): equations are printed as "ff pat = rhs"; |fmapIdent|
        -- does not appear on the left-hand side -- confirm this is intended.
        mkEq (pat, rhs)  =    pr (Var fmapFVar `Apply` exprFromHSPat pat)
                         <+>  text "=" <+> pr rhs
          where pr = precPprHSExpr ppr id 0
      in do
        fsSet fs $ Just (fmapIdent, blackhole)
        mdcAlts <- liftM sequence $ mapM (mkFMapAlt fmapIdent fmapFVar v irref) dcs
        case mdcAlts of
          Nothing -> do
            -- record the failure and discard all keys collected so far
            fsSet fs Nothing
            put Map.empty
            return Nothing
          Just dcAlts -> do
            -- NOTE(review): |fs| itself is not added to the collected key set
            -- here; nothing visible in this file inserts keys into that set.
            fsSet fs $ Just (fmapIdent, vcat $ map mkEq dcAlts)
            return $ Just fmapIdent

fsSet  :: FunctorSpec -> Maybe FMapImpl
       -> StateT (Map FunctorSpec ()) (State FunctorEnv) ()
-- Overwrite the cache entry for |fs| in the inner |FunctorEnv| state.
fsSet fs impl = lift (modify (Map.insert fs impl))
\end{code}
%}}}

%{{{ mkFMapAlt
\begin{code}
mkFMapAlt :: String -> SDoc -> TyVar -> Bool -> DataCon
  -> StateT  (Map FunctorSpec ())
    (State   FunctorEnv)
    (Maybe (HSPat DataCon SDoc, HSExpr DataCon SDoc))
-- One alternative of a generated |fmap| function for constructor |dc|:
-- the pattern matching |dc| applied to fresh variables from |args|, and
-- the expression rebuilding |dc| from the mapped arguments.
-- Returns |Nothing| if some argument cannot be handled by |wrapFMap|.
mkFMapAlt fmapIdent fmapFVar v irref dc = let
    (_univTyVars, _exTyVars, _eqSpec, _theta1, _theta2, argTys, _resTy) = dataConFullSig dc
    dcApp = (if irref then Irref else id) . DCApp dc
    -- left fold: |dc| applied to its argument expressions in order
    dcExpr = foldl Apply (DC dc)
  in do
    m  :: Maybe [((HSPat DataCon SDoc, HSExpr DataCon SDoc), Map FunctorSpec ())]
       <- liftM sequence . sequence $ zipWith (wrapFMap text fmapFVar v) argTys args
    case fmap unzip m of
      Nothing -> return Nothing
      Just (pps, ss) -> do
        -- add the keys reported by the arguments to the collected set
        modify (Map.union $ Map.unions ss)
        return . Just . (dcApp *** dcExpr) $ unzip pps
\end{code}
%}}}

%{{{ type MyVar
\begin{code}
-- NOTE(review): not referenced anywhere else in this file.
type MyVar = SDoc
\end{code}
%}}}

%{{{ wrapFMap
\begin{code}
-- Pattern/expression pair for one constructor argument |arg| of type |ty|
-- in a generated |fmap| alternative mapping over type variable |v|.
-- The pattern is always the variable |arg|; the expression depends on the
-- occurrences of |v| in |ty| reported by |tyVarPoss|.
-- The returned key set is always empty here.
wrapFMap :: forall dc v . (String -> v) -> v -> TyVar -> Type -> v
  ->   StateT (Map FunctorSpec ())
      (State FunctorEnv)
      (Maybe ((HSPat dc v, HSExpr dc v), Map FunctorSpec ()))
wrapFMap mkv fmapFVar v ty arg = case tyVarPoss v ty of
  -- no occurrence: pass the argument through unchanged
  []    -> return $ Just ((PVar arg, Var arg), Map.empty)
  -- a single occurrence at the empty position path (presumably |ty| is |v|
  -- itself -- TODO confirm): apply the mapped function
  [[]]  -> return $ Just ((PVar arg, Apply (Var fmapFVar) (Var arg)), Map.empty)
  -- NOTE(review): unfinished -- |compose| is unused, and on success the
  -- looked-up functions |_fs| are ignored.
  poss  -> let compose = parens . foldr1 (\ x y -> x <+> text "." <+> y)
    in do
      mfs  :: Maybe [v]
           <- liftM sequence $ mapM (getFMapPoss mkv fmapFVar) poss
      return $ case mfs of
        Nothing   ->  Nothing
        Just _fs  ->  Just ((PVar arg, Apply (Var fmapFVar) (Var arg)), Map.empty)
                      -- \edcomm{WK}{same as for |[[]]|!}
\end{code}
%}}}

%{{{ getFMapPoss
\begin{code}
getFMapPoss :: (String -> v) -> v -> [TyPos] -> StateT (Map FunctorSpec ()) (State FunctorEnv)  (Maybe v)
-- Process a position path from its last element back to its first,
-- starting from |fmapF|; the first failure short-circuits to |Nothing|.
getFMapPoss mkv fmapF = foldr step (return (Just fmapF))
  where
    step pos inner = inner >>= maybe (return Nothing) (\ f -> getFMapPos mkv f pos)
\end{code}
%}}}

%{{{ getFMapPos
\begin{code}
getFMapPos :: (String -> v) -> v -> TyPos
  -> StateT (Map FunctorSpec ()) (State FunctorEnv)  (Maybe v)
-- Only type-constructor positions are supported: the |fmap| name for that
-- constructor argument is taken from the cache (even a cached failure is
-- reused), or generated via |mkFMap| on a cache miss.
getFMapPos mkv _fmapFVar pos = case pos of
    TyConAppPos tc k -> do
      let fs = FunctorSpec tc k
      cached <- liftM (Map.lookup fs) (lift get)
      maybe (liftM (fmap mkv) (mkFMap fs))
            (return . fmap (mkv . fst))
            cached
    FunTyPos False -> error "getFMapPos (FunTyPos Left)"
    FunTyPos True  -> error "getFMapPos (FunTyPos Right)"
    AppTyPos False -> error "getFMapPos (AppTyPos Left)"
    AppTyPos True  -> error "getFMapPos (AppTyPos Right)"
\end{code}
%}}}

%{{{ mkFunctorEnv
\begin{code}
mkFunctorEnv :: [FunctorSpec] -> Map FunctorSpec FMapImpl
-- Generate |fmap| functions for all requested specs, then keep the cache
-- entries whose keys ended up in the collected key set.
mkFunctorEnv fss = Map.intersectionWithKey pick functorEnv keySet
  where
    -- both the key set and the cache start out empty
    ((_results, keySet), functorEnv) =
      runState (runStateT (mapM mkFMap fss) Map.empty) Map.empty
    -- a collected key must refer to a successful cache entry
    pick _fs (Just impl) _ = impl
    pick fs  Nothing     _ = error $ "mkFunctorEnv: no implementation for " ++ show fs
\end{code}
%}}}

|InstEnv.lookupInstEnv :: (InstEnv, InstEnv) -> Class -> [Type] -> ([InstMatch], [Instance])|

\begin{spec}

\end{spec}

%}}}

%{{{ EMACS lv
% Local Variables:
% folded-file: t
% fold-internal-margins: 0
% eval: (fold-set-marks "%{{{ " "%}}}")
% eval: (fold-whole-buffer)
% end:
%}}}
