{-# LANGUAGE RecordWildCards #-}
module Main (main) where

import GHC.Driver.Monad
import GHC
    ( runGhc, defaultErrorHandler
    , setTargets
    , parseModule, typecheckModule, desugarModule, coreModule
    , ParsedModule(..), TypecheckedModule(..)
    , setSessionDynFlags
    )

import GHC.Driver.Main (hscSimplify)
import GHC.Driver.Session
import GHC.Driver.Backend
import GHC.Unit.Module.Name
import GHC.Unit.State
import GHC.Unit.Types
import GHC.Utils.Error
import GHC.Driver.Make
import GHC.Driver.Env.Types
import GHC.Driver.Phases
import GHC.Unit.Module.Env
import GHC.Unit.Module.ModIface
import GHC.Unit.Module.ModDetails
import GHC.Unit.Module.ModSummary
import GHC.Unit.Module.Graph
import GHC.Unit.External
import GHC.Unit.Env
import GHC.Unit.Home
import GHC.Unit.Home.ModInfo
import GHC.Driver.Errors.Types
import GHC.Types.Target
import GHC.Types.SourceFile
import GHC.Types.SafeHaskell
import GHC.Types.Error
import GHC.Hs.Instances ()
import GHC.Plugins
import GHC.Tc.Types
import qualified Data.Map as M
import GHC.Iface.Make
import GHC.Types.TypeEnv
import GHC.Iface.Tidy (tidyProgram)
import GHC.CoreToStg.Prep

import GHC.Data.IOEnv
import GHC.Tc.Utils.Monad (updateEps_)
import GHC.Core.InstEnv
import GHC.Core.FamInstEnv


import Control.Monad (unless, (<=<))
import Data.Either (partitionEithers)
import System.Exit
import System.FilePath

import qualified GHC.Data.EnumSet as EnumSet
import qualified GHC.LanguageExtensions.Type as Ext

-- | Root of the GHC HEAD library directory handed to 'runGhc'; 'setup' also
-- reads the package database from @libDir </> "package.conf.d"@.
-- Set this to wherever you have the GHC HEAD libs built.
libDir :: FilePath
libDir = "/home/mi/prog/ghc/_build/stage1/lib"

-- | Compile the example modules down to CorePrep'd Core and print every
-- resulting binding. Source errors are reported and terminate the program
-- (see 'withSourceErrors').
main :: IO ()
main = defaultErrorHandler defaultFatalMessager defaultFlushOut $
    runGhc (Just libDir) $ withSourceErrors $ do
        setup

        -- The modules may come from any mix of units; this example places
        -- each one in a unit of its own.
        let pipeline = compileToCore <=< loadModule <=< uncurry prepareModule
            inputs =
                [ (unitA, "Class")
                , (unitB, "Instance")
                , (unitC, "Use")
                ]
        programs <- traverse pipeline inputs
        let allBinds = concat programs

        ctx <- initDefaultSDocContext <$> getDynFlags
        liftIO . putStrLn . renderWithContext ctx . vcat $ map ppr allBinds

-- | The three units the example modules are compiled into, one per module.
unitA, unitB, unitC :: UnitId
unitA = UnitId (fsLit "unitA")
unitB = UnitId (fsLit "unitB")
unitC = UnitId (fsLit "unitC")

-- | Run a GHC action; if it throws a 'SourceError', print it and exit the
-- process with status 1.
withSourceErrors :: (GhcMonad m) => m a -> m a
withSourceErrors = handleSourceError reportAndExit
  where
    reportAndExit err = printException err >> liftIO (exitWith (ExitFailure 1))

-- | Configure the session's DynFlags and then clear cached summary dates
-- (see 'invalidateModSummaryCache'). The new flags:
--
--   * start from -O1 and enable the specialisation, rewrite-rule and
--     constant-dictionary flags listed below, plus 'Opt_NoTypeableBinds';
--   * select 'NoBackend' (no code generation);
--   * set the main module name to 'noMainModule';
--   * use only the package database at @libDir </> "package.conf.d"@.
setup :: (GhcMonad m) => m ()
setup = do
    initial <- getSessionDynFlags
    let -- We'd like to set just the flags in this list, but alas, GHC #20500,
        -- hence the blanket updOptLevel 1 underneath them.
        -- ('gopt_set' only inserts into a flag set, so the fold order is
        -- irrelevant.)
        optimised = foldr (flip gopt_set) (updOptLevel 1 initial) wanted
        wanted =
            [ Opt_LateSpecialise
            , Opt_Specialise
            , Opt_SpecialiseAggressively
            , Opt_CrossModuleSpecialise
            , Opt_EnableRewriteRules
            , Opt_SolveConstantDicts
            , Opt_NoTypeableBinds
            ]
        -- To inspect the specialiser's output, also apply:
        --   dopt_set flags Opt_D_dump_spec
        final = optimised
            { backend = NoBackend
            , mainModuleNameIs = noMainModule
              -- NOTE(review): GHC appears to keep this list in reverse
              -- command-line order, i.e. "clear the defaults, then add this
              -- DB" — confirm against GHC.Driver.Session.
            , packageDBFlags = [PackageDB $ PkgDbPath (libDir </> "package.conf.d"), ClearPackageDBs]
            }

    setSessionDynFlags final
    invalidateModSummaryCache

-- | Module name installed as the session's main module by 'setup'. It
-- contains a @/@, which is not legal in a Haskell module name, so it cannot
-- coincide with a real module (presumably so that none of the inputs is
-- treated as @Main@).
noMainModule :: ModuleName
noMainModule = mkModuleName impossibleName
  where
    impossibleName = "No/Main"

-- | Turn off the implicit @import Prelude@ for modules compiled with this
-- session.
--
-- Uses 'xopt_unset' rather than deleting 'Ext.ImplicitPrelude' from
-- 'extensionFlags' directly. 'xopt_unset' also records the explicit
-- @Off ImplicitPrelude@ in the DynFlags' extension list. A bare deletion is
-- lost as soon as GHC re-derives 'extensionFlags' from the language edition
-- plus that list, e.g. when a module's LANGUAGE pragma selects a language
-- edition, and the implicit Prelude would then silently come back.
noPrelude :: HscEnv -> HscEnv
noPrelude env = env
    { hsc_dflags = xopt_unset (hsc_dflags env) Ext.ImplicitPrelude
    }

-- | Parse and typecheck the summarised module (via 'prepareSource'), register
-- its interface and details with the session (via 'registerModule'), and
-- return the typechecked module.
loadModule :: (GhcMonad m) => ModSummary -> m TypecheckedModule
loadModule summary =
    prepareSource summary >>= \(checked, iface, details) ->
        checked <$ registerModule iface details

-- | Take a typechecked module through desugaring, 'hscSimplify', 'tidyProgram'
-- and 'corePrepPgm', returning the CorePrep'd bindings.
--
-- 'hscSimplify' receives an empty list (presumably: no plugin options), and
-- only the data type constructors ('isDataTyCon') are passed to 'corePrepPgm'.
compileToCore :: (GhcMonad m) => TypecheckedModule -> m CoreProgram
compileToCore tmod = do
    desugared <- desugarModule tmod
    env <- getSession
    simplified <- liftIO $ hscSimplify env [] (coreModule desugared)
    -- The tidied ModDetails are not needed here.
    (cgGuts, _tidyDetails) <- liftIO $ tidyProgram env simplified

    let location = ms_location (modSummary tmod)
        dataTyCons = filter isDataTyCon (cg_tycons cgGuts)
    liftIO $ corePrepPgm env (cg_module cgGuts) location (cg_binds cgGuts) dataTyCons

-- | Produce the 'ModSummary' for module @modName@ of unit @unitId@.
--
-- Effects on the session:
--
--   * the targets become this module's source file (see 'resolve');
--   * @unitId@ becomes the home unit and the implicit Prelude is turned off.
--
-- Module names already present in the unit state's providers map are
-- excluded from the downsweep, so previously registered modules are left
-- as-is.
--
-- Downsweep errors are raised via 'reportErrors'. If the downsweep yields no
-- summary for the requested module, an 'IOError' naming the module and unit
-- is thrown. This happens, for example, when the file's module header
-- declares a different name; previously this case was an opaque partial
-- pattern failure.
prepareModule :: (GhcMonad m) => UnitId -> String -> m ModSummary
prepareModule unitId modName = do
    let mname = mkModuleName modName
    setTargets [resolve unitId mname]

    -- Anything already in the provided map should be left as-is
    excluded <- M.keys . moduleNameProvidersMap . hsc_units <$> getSession

    env <- noPrelude . setHomeUnit unitId <$> getSession
    setSession env

    (errs, summaries) <- liftIO $ partitionEithers <$> downsweep env [] excluded False
    reportErrors $ fmap GhcDriverMessage <$> errs

    let byModule = mkModuleEnv [(ms_mod (emsModSummary ems), ems) | ems <- summaries]
    case lookupModuleEnv byModule (mkModule unit mname) of
        Just ems -> return $ emsModSummary ems
        Nothing -> liftIO . ioError . userError $
            "prepareModule: no module summary found for module "
                ++ moduleNameString mname ++ " in unit " ++ unitIdString unitId
  where
    unit = RealUnit . Definite $ unitId

-- | Merge the given message collections and throw them (via 'throwErrors')
-- if the merged result is non-empty; otherwise do nothing.
reportErrors :: (GhcMonad m) => [ErrorMessages] -> m ()
reportErrors batches =
    let combined = unionManyMessages batches
    in if isEmptyMessages combined
        then return ()
        else throwErrors combined

-- | Parse and typecheck a module, then build its 'ModIface' from the
-- typechecker's global environment and 'ModDetails' (via 'mkModIface').
-- Returns the typechecked module together with that interface and details.
prepareSource
    :: (GhcMonad m)
    => ModSummary
    -> m (TypecheckedModule, ModIface, ModDetails)
prepareSource summary = do
    checked <- typecheckModule =<< parseModule summary

    hscEnv <- getSession
    let (gblEnv, modDetails) = tm_internals_ checked
    modIface <- liftIO $ mkModIface hscEnv gblEnv modDetails (modSummary checked)
    pure (checked, modIface, modDetails)

-- | Build the 'Target' for a module of the given unit. The source lives under
-- @input/@, at the module's slash-separated path with a @.src@ extension
-- (e.g. @input/Class.src@), and is run through the C preprocessor as a
-- Haskell source file. Object code is not allowed and no in-memory contents
-- are supplied.
resolve :: UnitId -> ModuleName -> Target
resolve unitId modName = Target
    { targetId = TargetFile sourcePath (Just $ Cpp HsSrcFile)
    , targetAllowObjCode = False
    , targetContents = Nothing
    , targetUnitId = unitId
    }
  where
    sourcePath = "input" </> (moduleNameSlashes modName <.> "src")

-- | Reset 'ms_hie_date' to 'Nothing' on every 'ModSummary' in the session's
-- module graph (presumably so that cached summaries are not considered up to
-- date). An alternative that was considered is to backdate the existing
-- value instead: @addUTCTime (-1) <$> ms_hie_date ms@.
invalidateModSummaryCache :: (GhcMonad m) => m ()
invalidateModSummaryCache = modifySession resetGraph
  where
    resetGraph env = env { hsc_mod_graph = mapMG forgetHieDate (hsc_mod_graph env) }
    forgetHieDate summary = summary { ms_hie_date = Nothing }

-- | Make the given (definite) unit the session's home unit, leaving the rest
-- of the unit environment untouched.
setHomeUnit :: UnitId -> HscEnv -> HscEnv
setHomeUnit unitId env = env { hsc_unit_env = withHome (hsc_unit_env env) }
  where
    withHome unitEnv = unitEnv { ue_home_unit = Just (DefiniteHomeUnit unitId Nothing) }

-- | Build a module interface straight from typechecker output via
-- 'mkIfaceTc', with the Safe Haskell mode set to 'Sf_Ignore'.
mkModIface :: HscEnv -> TcGblEnv -> ModDetails -> ModSummary -> IO ModIface
mkModIface env gblEnv details summary =
    mkIfaceTc env safeMode details summary gblEnv
  where
    safeMode = Sf_Ignore

-- | Apply a function to the 'UnitState' stored in the session's unit
-- environment.
modifyUnitState :: (UnitState -> UnitState) -> HscEnv -> HscEnv
modifyUnitState f env = env { hsc_unit_env = overUnits (hsc_unit_env env) }
  where
    overUnits unitEnv = unitEnv { ue_units = f (ue_units unitEnv) }

-- | Make a freshly typechecked module visible to modules compiled later in
-- this session. The module is recorded in three places:
--
--   * the External Package State (via 'updateEps_'): its interface goes into
--     the package interface table, and its types, rules, class instances,
--     family instances and annotations are merged into the EPS environments;
--   * the unit state's module-name providers map, where the module's name is
--     mapped to this module alone (any previous providers of that name are
--     replaced, since 'M.insert' overwrites the entry);
--   * the home package table, as a 'HomeModInfo' built from the interface and
--     details (third field 'Nothing' — presumably "no linkable"; confirm).
registerModule :: (GhcMonad m) => ModIface -> ModDetails -> m ()
registerModule iface details@ModDetails{..} = do
    env <- getSession
    -- 'updateEps_' runs in the TcRnIf monad, so it is executed with runIOEnv
    -- over a hand-built Env: the session's HscEnv, a '\0' placeholder
    -- (NOTE(review): presumably the unique mask — confirm against Env), and
    -- unit values for the global and local environments.
    liftIO $ runIOEnv (Env env '\0' () ()) $ updateEps_ extendEps
    -- addModule is applied first; extendHpt then sees its result.
    modifySession $ extendHpt . addModule
  where
    mod_info = HomeModInfo iface details Nothing

    mod = mi_module iface
    -- NOTE(review): presumably "exposed from its own unit, no re-exports,
    -- brought in by a package flag" — confirm against ModuleOrigin's fields.
    modOrig = ModOrigin (Just True) [] [] True

    addModule = modifyUnitState $ \us -> us
        { moduleNameProvidersMap = M.insert (moduleName mod) (M.singleton mod modOrig) $ moduleNameProvidersMap us
        }

    extendHpt env = env
        { hsc_unit_env = let ue = hsc_unit_env env in ue
            { ue_hpt = hpt
            }
        }
      where
        -- Extends the HPT already present in the env being updated.
        hpt = addToHpt (hsc_HPT env) (moduleName mod) mod_info

    -- md_types, md_insts, md_fam_insts, md_rules and md_anns come from the
    -- ModDetails{..} pattern above (RecordWildCards).
    extendEps :: ExternalPackageState -> ExternalPackageState
    extendEps eps = eps
        { eps_PIT = extendModuleEnv (eps_PIT eps) mod iface
        , eps_PTE = plusTypeEnv (eps_PTE eps) md_types
        , eps_rule_base = extendRuleBaseList (eps_rule_base eps) md_rules
        , eps_inst_env = extendInstEnvList (eps_inst_env eps) md_insts
        , eps_fam_inst_env = extendFamInstEnvList (eps_fam_inst_env eps) md_fam_insts
        , eps_ann_env = extendAnnEnvList (eps_ann_env eps) md_anns
          -- Per-module family-instance env containing only this module's
          -- family instances.
        , eps_mod_fam_inst_env =
                let fam_inst_env = extendFamInstEnvList emptyFamInstEnv md_fam_insts
                in extendModuleEnv (eps_mod_fam_inst_env eps) mod fam_inst_env
          -- Bump the EPS counters by the number of types, instances and
          -- rules just added (argument order as addEpsInStats expects —
          -- confirm if that signature changes).
        , eps_stats = addEpsInStats (eps_stats eps) (length $ typeEnvElts md_types) (length md_insts) (length md_rules)
        }
        }

-- | The 'ModSummary' that a typechecked module was produced from.
--
-- TODO: GHC should export this as a method of its ParsedMod class, i.e.
-- @modSummary :: ParsedMod m => m -> ModSummary@.
modSummary :: TypecheckedModule -> ModSummary
modSummary tmod = pm_mod_summary (tm_parsed_module tmod)
