{-# LANGUAGE CPP #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
module Language.Haskell.TH.ExpandSyns(-- * Expand synonyms
                                      expandSyns
                                     ,expandSynsWith
                                     ,SynonymExpansionSettings
                                     ,noWarnTypeFamilies

                                      -- * Misc utilities
                                     ,substInType
                                     ,substInCon
                                     ,evades,evade) where

import Language.Haskell.TH.Datatype.TyVarBndr
import Language.Haskell.TH.ExpandSyns.SemigroupCompat as Sem
import Language.Haskell.TH hiding(cxt)
import qualified Data.Set as Set
import Data.Generics
import Data.Maybe
import Control.Monad
import Prelude

#if !(MIN_VERSION_base(4,8,0))
import Control.Applicative
#endif

-- For ghci
#ifndef MIN_VERSION_template_haskell
#define MIN_VERSION_template_haskell(X,Y,Z) 1
#endif

-- | The package name, used as a prefix for the warnings and error messages
-- produced by this module.
packagename :: String
packagename = "th-expand-syns"

-- | Replace the name bound by a type-variable binder. Only the name is
-- touched (via 'mapTVName'); the binder's flag type is preserved.
tyVarBndrSetName :: Name -> TyVarBndr_ flag -> TyVarBndr_ flag
tyVarBndrSetName newName = mapTVName (\_ -> newName)

#if MIN_VERSION_template_haskell(2,10,0)
-- mapPred is not needed for template-haskell >= 2.10
#else
-- | Apply a pure function to every type mentioned by an old-style
-- (template-haskell < 2.10) class or equality predicate.
mapPred :: (Type -> Type) -> Pred -> Pred
mapPred f p = case p of
  ClassP cls args -> ClassP cls (map f args)
  EqualP lhs rhs  -> EqualP (f lhs) (f rhs)
#endif

#if MIN_VERSION_template_haskell(2,10,0)
-- | Run a monadic type transformation over a predicate. From
-- template-haskell 2.10 on, a predicate is just a 'Type', so the
-- transformation is applied to it directly.
bindPred :: (Type -> Q Type) -> Pred -> Q Pred
bindPred f = f
#else
-- | Run a monadic type transformation over every type inside an old-style
-- (template-haskell < 2.10) predicate, left to right.
bindPred :: (Type -> Q Type) -> Pred -> Q Pred
bindPred f p = case p of
  ClassP cls args -> ClassP cls <$> mapM f args
  EqualP lhs rhs  -> EqualP <$> f lhs <*> f rhs
#endif

-- | Options for 'expandSynsWith'. Start from 'mempty' (the defaults) and
-- combine with '<>'.
data SynonymExpansionSettings = SynonymExpansionSettings
  { sesWarnTypeFamilies :: Bool
    -- ^ Whether to warn when a type synonym family is encountered
  }


-- | Combining settings: the type-family warning stays enabled only if both
-- sides enable it.
instance Semigroup SynonymExpansionSettings where
  SynonymExpansionSettings warnL <> SynonymExpansionSettings warnR =
    SynonymExpansionSettings { sesWarnTypeFamilies = warnL && warnR }

-- | Default settings ('mempty'):
--
-- * Warn if type families are encountered.
--
-- (The 'mappend' is currently rather useless; the monoid instance is intended
-- for additional settings in the future.)
instance Monoid SynonymExpansionSettings where
  mempty = SynonymExpansionSettings True

#if !MIN_VERSION_base(4,11,0)
  -- starting with base-4.11, mappend definitions are redundant;
  -- at some point `mappend` will be removed from `Monoid`
  mappend = (Sem.<>)
#endif


-- | Suppresses the warning that type families are unsupported.
--
-- Defined as a record update of 'mempty', so every other setting keeps its
-- default value.
noWarnTypeFamilies :: SynonymExpansionSettings
noWarnTypeFamilies = defaults { sesWarnTypeFamilies = False }
  where
    defaults :: SynonymExpansionSettings
    defaults = mempty

-- | Emit a non-fatal compile-time warning, prefixed with 'packagename'.
warn :: String -> Q ()
warn msg =
#if MIN_VERSION_template_haskell(2,8,0)
    reportWarning
#else
    report False
#endif
      (concat [packagename, ": WARNING: ", msg])




-- | A type synonym's parameter names together with its right-hand side.
type SynInfo = ([Name], Type)

-- | Reify a name and decide whether it is a type synonym. Type constructors
-- (and, where supported, families) are passed on to 'decIsSyn'; every other
-- kind of entity is reported as not being a synonym.
nameIsSyn :: SynonymExpansionSettings -> Name -> Q (Maybe SynInfo)
nameIsSyn settings n = reify n >>= classify
  where
    classify :: Info -> Q (Maybe SynInfo)
    classify info = case info of
      TyConI dec -> decIsSyn settings dec
#if MIN_VERSION_template_haskell(2,7,0)
      -- Families are never expanded; decIsSyn is called so it can warn.
      FamilyI dec _ -> decIsSyn settings dec
#endif
      ClassI {} -> notSyn
      ClassOpI {} -> notSyn
      PrimTyConI {} -> notSyn
      DataConI {} -> notSyn
      VarI {} -> notSyn
      TyVarI {} -> notSyn
#if MIN_VERSION_template_haskell(2,12,0)
      PatSynI {} -> notSyn
#endif

    notSyn :: Q (Maybe SynInfo)
    notSyn = return Nothing

-- | Classify a declaration obtained from 'reify'.
--
-- * A type synonym yields its parameter names and right-hand side.
-- * A type synonym family yields 'Nothing' after (optionally) warning via
--   'maybeWarnTypeFamily' that it won't be expanded.
-- * Every other declaration yields 'Nothing'.
decIsSyn :: SynonymExpansionSettings -> Dec -> Q (Maybe SynInfo)
decIsSyn settings = go
  where
    go :: Dec -> Q (Maybe SynInfo)
    go (TySynD _ vars t) = return (Just (tvName <$> vars, t))

#if MIN_VERSION_template_haskell(2,11,0)
    go (OpenTypeFamilyD (TypeFamilyHead name _ _ _)) = maybeWarnTypeFamily settings name >> no
    go (ClosedTypeFamilyD (TypeFamilyHead name _ _ _) _) = maybeWarnTypeFamily settings name >> no
#else

#if MIN_VERSION_template_haskell(2,9,0)
    go (ClosedTypeFamilyD name _ _ _) = maybeWarnTypeFamily settings name >> no
#endif

    go (FamilyD TypeFam name _ _) = maybeWarnTypeFamily settings name >> no
#endif

    go (FunD {}) = no
    go (ValD {}) = no
    go (DataD {}) = no
    go (NewtypeD {}) = no
    go (ClassD {}) = no
    go (InstanceD {}) = no
    go (SigD {}) = no
    go (ForeignD {}) = no

#if MIN_VERSION_template_haskell(2,8,0)
    go (InfixD {}) = no
#endif

    go (PragmaD {}) = no

    -- Nothing to expand for data families, so no warning
#if MIN_VERSION_template_haskell(2,11,0)
    go (DataFamilyD {}) = no
#else
    go (FamilyD DataFam _ _ _) = no
#endif

    go (DataInstD {}) = no
    go (NewtypeInstD {}) = no
    go (TySynInstD {}) = no

#if MIN_VERSION_template_haskell(2,9,0)
    go (RoleAnnotD {}) = no
#endif

#if MIN_VERSION_template_haskell(2,10,0)
    go (StandaloneDerivD {}) = no
    go (DefaultSigD {}) = no
#endif

#if MIN_VERSION_template_haskell(2,12,0)
    go (PatSynD {}) = no
    go (PatSynSigD {}) = no
#endif

#if MIN_VERSION_template_haskell(2,15,0)
    go (ImplicitParamBindD {}) = no
#endif

#if MIN_VERSION_template_haskell(2,16,0)
    go (KiSigD {}) = no
#endif

    -- Constructors added in newer template-haskell releases. Without these
    -- equations 'go' is non-exhaustive there and fails at runtime.
#if MIN_VERSION_template_haskell(2,19,0)
    go (DefaultD {}) = no
#endif

#if MIN_VERSION_template_haskell(2,20,0)
    go (TypeDataD {}) = no
#endif

    no :: Q (Maybe SynInfo)
    no = return Nothing

-- | Warn that the named type synonym family won't be expanded, unless the
-- warning was disabled in the settings (see 'noWarnTypeFamilies').
maybeWarnTypeFamily :: SynonymExpansionSettings -> Name -> Q ()
maybeWarnTypeFamily settings name
  | sesWarnTypeFamilies settings =
      warn ("Type synonym families (and associated type synonyms) are currently not supported (they won't be expanded). Name of unsupported family: " ++ show name)
  | otherwise = return ()







-- | Calls 'expandSynsWith' with the default settings ('mempty').
expandSyns :: Type -> Q Type
expandSyns ty = expandSynsWith mempty ty


-- | Expands all type synonyms in the given type. Type families currently won't be expanded (but will be passed through).
expandSynsWith :: SynonymExpansionSettings -> Type -> Q Type
expandSynsWith settings = expandSyns'

    where
      -- Expand a complete type: walk down its application spine with 'go',
      -- then re-apply the (already expanded) arguments not consumed by a synonym.
      expandSyns' :: Type -> Q Type
      expandSyns' t =
         do
           (acc, t') <- go [] t
           return (foldl applyTypeArg t' acc)

      -- Like expandSyns', but for kinds.
      expandKindSyns' :: Kind -> Q Kind
      expandKindSyns' k =
#if MIN_VERSION_template_haskell(2,8,0)
         do
           (acc, k') <- go [] k
           return (foldl applyTypeArg k' acc)
#else
         return k -- No kind variables on old versions of GHC
#endif

      applyTypeArg :: Type -> TypeArg -> Type
      applyTypeArg f (TANormal x) = f `AppT` x
      applyTypeArg f (TyArg _x)   =
#if __GLASGOW_HASKELL__ >= 807
                                    f `AppKindT` _x
#else
                                    -- VKA isn't supported, so
                                    -- conservatively drop the argument
                                    f
#endif

      -- Take the arguments that instantiate a synonym with @n@ parameters:
      -- the first @n@ normal arguments, skipping any visible kind
      -- applications that precede the last of them. Returns the taken types
      -- and the remaining arguments, or 'Nothing' if there are fewer than
      -- @n@ normal arguments (i.e. the synonym is underapplied).
      --
      -- Counting only normal arguments matters: @T \@Type Int@ with
      -- @type T (a :: k) = ...@ has two arguments but instantiates one
      -- parameter. It must not leave @Int@ dangling after expansion.
      takeNormalArgs :: Int -> [TypeArg] -> Maybe ([Type], [TypeArg])
      takeNormalArgs n args
        | n <= 0 = Just ([], args)
      takeNormalArgs _ [] = Nothing
      takeNormalArgs n (TyArg _ : rest) = takeNormalArgs n rest
      takeNormalArgs n (TANormal ty : rest) = do
        (tys, rest') <- takeNormalArgs (n - 1) rest
        return (ty : tys, rest')

      -- Must only be called on an `x' requiring no expansion
      passThrough :: [TypeArg] -> Type -> Q ([TypeArg], Type)
      passThrough acc x = return (acc, x)

      forallAppError :: [TypeArg] -> Type -> Q a
      forallAppError acc x =
          fail (packagename++": Unexpected application of the local quantification: "
                ++show x
                ++"\n    (to the arguments "++show acc++")")

#if MIN_VERSION_template_haskell(2,11,0)
      -- Infix types: expand both operands and keep the operator as-is.
      goInfix :: (Type -> Name -> Type -> Type)
              -> [TypeArg] -> Type -> Name -> Type -> Q ([TypeArg], Type)
      goInfix con acc t1 nm t2 =
          do
            t1' <- expandSyns' t1
            t2' <- expandSyns' t2
            return (acc, con t1' nm t2')
#endif

      -- If @go args t = (args', t')@,
      --
      -- Precondition:
      --  All elements of `args' are expanded.
      -- Postcondition:
      --  All elements of `args'' and `t'' are expanded.
      --  `t' applied to `args' equals `t'' applied to `args'' (up to expansion, of course)

      go :: [TypeArg] -> Type -> Q ([TypeArg], Type)

      go acc x@ListT = passThrough acc x
      go acc x@ArrowT = passThrough acc x
      go acc x@(TupleT _) = passThrough acc x
      go acc x@(VarT _) = passThrough acc x

      go [] (ForallT ns ctx t) = do
        ctx' <- mapM (bindPred expandSyns') ctx
        t' <- expandSyns' t
        return ([], ForallT ns ctx' t')

      go acc x@ForallT{} = forallAppError acc x

      go acc (AppT t1 t2) =
          do
            r <- expandSyns' t2
            go (TANormal r : acc) t1

      go acc x@(ConT n) =
          do
            i <- nameIsSyn settings n
            case i of
              Nothing -> return (acc, x)
              Just (vars, body) ->
                  case takeNormalArgs (length vars) acc of
                    Nothing ->
                        fail (packagename++": expandSynsWith: Underapplied type synonym: "++show(n,acc))
                    Just (args, rest) ->
                        go rest (doSubsts (zip vars args) body)

      go acc (SigT t kind) =
          do
            (acc', t') <- go acc t
            kind' <- expandKindSyns' kind
            return (acc', SigT t' kind')

#if MIN_VERSION_template_haskell(2,6,0)
      go acc x@(UnboxedTupleT _) = passThrough acc x
#endif

#if MIN_VERSION_template_haskell(2,8,0)
      go acc x@(PromotedT _) = passThrough acc x
      go acc x@(PromotedTupleT _) = passThrough acc x
      go acc x@PromotedConsT = passThrough acc x
      go acc x@PromotedNilT = passThrough acc x
      go acc x@StarT = passThrough acc x
      go acc x@ConstraintT = passThrough acc x
      go acc x@(LitT _) = passThrough acc x
#endif

#if MIN_VERSION_template_haskell(2,10,0)
      go acc x@EqualityT = passThrough acc x
#endif

#if MIN_VERSION_template_haskell(2,11,0)
      go acc (InfixT t1 nm t2) = goInfix InfixT acc t1 nm t2
      go acc (UInfixT t1 nm t2) = goInfix UInfixT acc t1 nm t2
      go acc (ParensT t) =
          do
            (acc', t') <- go acc t
            return (acc', ParensT t')
      go acc x@WildCardT = passThrough acc x
#endif

#if MIN_VERSION_template_haskell(2,12,0)
      go acc x@(UnboxedSumT _) = passThrough acc x
#endif

#if MIN_VERSION_template_haskell(2,15,0)
      go acc (AppKindT t k) =
          do
            k' <- expandKindSyns' k
            go (TyArg k' : acc) t
      go acc (ImplicitParamT n t) =
          do
            (acc', t') <- go acc t
            return (acc', ImplicitParamT n t')
#endif

#if MIN_VERSION_template_haskell(2,16,0)
      go [] (ForallVisT ns t) = do
        t' <- expandSyns' t
        return ([], ForallVisT ns t')

      go acc x@ForallVisT{} = forallAppError acc x
#endif

#if MIN_VERSION_template_haskell(2,17,0)
      go acc x@MulArrowT = passThrough acc x
#endif

#if MIN_VERSION_template_haskell(2,19,0)
      -- Promoted infix operators (template-haskell >= 2.19); previously a
      -- runtime pattern-match failure.
      go acc (PromotedInfixT t1 nm t2) = goInfix PromotedInfixT acc t1 nm t2
      go acc (PromotedUInfixT t1 nm t2) = goInfix PromotedUInfixT acc t1 nm t2
#endif

-- | One argument on a type-application spine: either an ordinary type
-- argument ('TANormal') or a visible kind application ('TyArg').
data TypeArg
  = TANormal Type -- ^ Normal arguments
  | TyArg    Kind -- ^ Visible kind applications
  deriving Show

-- | Things into which a type can be substituted for a type variable.
class SubstTypeVariable a where
  -- | Capture-free substitution: @subst (v, t) x@ replaces the type variable
  -- @v@ with @t@ throughout @x@.
  subst :: (Name, Type) -> a -> a



-- | Substitution into types. Binders introduced by a @forall@ are handled by
-- 'commonForallCase', which prevents the replacement type from being
-- captured.
instance SubstTypeVariable Type where
  subst vt@(v, t) = go
    where
      go :: Type -> Type
      go (AppT x y) = AppT (go x) (go y)
      go s@(ConT _) = s
      go s@(VarT w) | v == w = t
                    | otherwise = s
      go ArrowT = ArrowT
      go ListT = ListT
      go (ForallT vars ctx body) =
          commonForallCase vt vars $ \vts' vars' ->
          ForallT vars' (map (doSubsts vts') ctx) (doSubsts vts' body)

      go s@(TupleT _) = s

      go (SigT t1 kind) = SigT (go t1) (subst vt kind)

#if MIN_VERSION_template_haskell(2,6,0)
      go s@(UnboxedTupleT _) = s
#endif

#if MIN_VERSION_template_haskell(2,8,0)
      go s@(PromotedT _) = s
      go s@(PromotedTupleT _) = s
      go s@PromotedConsT = s
      go s@PromotedNilT = s
      go s@StarT = s
      go s@ConstraintT = s
      go s@(LitT _) = s
#endif

#if MIN_VERSION_template_haskell(2,10,0)
      go s@EqualityT = s
#endif

#if MIN_VERSION_template_haskell(2,11,0)
      go (InfixT t1 nm t2) = InfixT (go t1) nm (go t2)
      go (UInfixT t1 nm t2) = UInfixT (go t1) nm (go t2)
      go (ParensT t1) = ParensT (go t1)
      go s@WildCardT = s
#endif

#if MIN_VERSION_template_haskell(2,12,0)
      go s@(UnboxedSumT _) = s
#endif

#if MIN_VERSION_template_haskell(2,15,0)
      go (AppKindT ty ki) = AppKindT (go ty) (go ki)
      go (ImplicitParamT n ty) = ImplicitParamT n (go ty)
#endif

#if MIN_VERSION_template_haskell(2,16,0)
      go (ForallVisT vars body) =
          commonForallCase vt vars $ \vts' vars' ->
          ForallVisT vars' (doSubsts vts' body)
#endif

#if MIN_VERSION_template_haskell(2,17,0)
      go MulArrowT = MulArrowT
#endif

#if MIN_VERSION_template_haskell(2,19,0)
      -- Promoted infix operators (template-haskell >= 2.19): substitute in
      -- both operands. Without these equations 'go' is non-exhaustive.
      go (PromotedInfixT t1 nm t2) = PromotedInfixT (go t1) nm (go t2)
      go (PromotedUInfixT t1 nm t2) = PromotedUInfixT (go t1) nm (go t2)
#endif

-- testCapture :: Type
-- testCapture =
--     let
--         n = mkName
--         v = VarT . mkName
--     in
--       substInType (n "x", v "y" `AppT` v "z")
--                   (ForallT
--                    [n "y",n "z"]
--                    [ConT (mkName "Show") `AppT` v "x" `AppT` v "z"]
--                    (v "x" `AppT` v "y"))


#if !MIN_VERSION_template_haskell(2,10,0)
-- | Old-style predicates: substitute into every type they mention.
instance SubstTypeVariable Pred where
    subst s p = mapPred (subst s) p
#endif

#if !MIN_VERSION_template_haskell(2,8,0)
-- | Before template-haskell 2.8 kinds cannot mention type variables, so
-- substitution leaves them unchanged.
instance SubstTypeVariable Kind where
    subst _ k = k -- No kind variables on old versions of GHC
#endif

-- | Make a name (based on the first arg) that's distinct from every name in the second arg
--
-- Example why this is necessary:
--
-- > type E x = forall y. Either x y
-- >
-- > ... expandSyns [t| forall y. y -> E y |]
--
-- The example as given may actually work correctly without any special capture-avoidance depending
-- on how GHC handles the @y@s, but in any case, the input type to expandSyns may be an explicit
-- AST using 'mkName' to ensure a collision.
--
-- The candidate is renamed repeatedly until it no longer occurs in the second
-- argument.
evade :: Data d => Name -> d -> Name
evade n t = until (`Set.notMember` used) bump n
  where
    -- Every 'Name' occurring anywhere inside t.
    used :: Set.Set Name
    used = everything Set.union (mkQ Set.empty Set.singleton) t

    -- Rebuild the name from its 'nameBase' with an extra leading 'f'.
    bump :: Name -> Name
    bump = mkName . ('f' :) . nameBase

-- | Make a list of names (based on the first arg) such that every name in the result
-- is distinct from every name in the second arg, and from the other results
--
-- Names are chosen from the end of the list backwards: each one avoids the
-- second argument and all names already chosen for the elements after it.
evades :: (Data t) => [Name] -> t -> [Name]
evades ns t = pick ns
  where
    pick [] = []
    pick (n : rest) =
      let chosen = pick rest
      in evade n (chosen, t) : chosen

-- evadeTest = let v = mkName "x"
--             in
--               evade v (AppT (VarT v) (VarT (mkName "fx")))

-- | Substitution into data constructors: applied to every field type, and
-- capture-avoiding under 'ForallC'. GADT constructors are rejected with an
-- error.
instance SubstTypeVariable Con where
  subst vt = go
    where
      -- Substitution into a single field type.
      st :: Type -> Type
      st = subst vt

      go :: Con -> Con
      go (NormalC con fields) = NormalC con [ (bang, st ty) | (bang, ty) <- fields ]
      go (RecC con fields) = RecC con [ (fld, bang, st ty) | (fld, bang, ty) <- fields ]
      go (InfixC (bangL, tyL) op (bangR, tyR)) = InfixC (bangL, st tyL) op (bangR, st tyR)
      go (ForallC vars ctx body) =
          commonForallCase vt vars $ \vts' vars' ->
          ForallC vars' (map (doSubsts vts') ctx) (doSubsts vts' body)
#if MIN_VERSION_template_haskell(2,11,0)
      go c@GadtC{} = errGadt c
      go c@RecGadtC{} = errGadt c

      errGadt :: Con -> a
      errGadt c = error (packagename++": substInCon currently doesn't support GADT constructors with GHC >= 8 ("++pprint c++")")
#endif

-- | Apply a substitution to something underneath a @forall@.
--
-- @commonForallCase (v, t) bndrs k@ calls the continuation @k@ with two
-- arguments:
--
--   * the substitutions to apply to everything scoped under the @forall@;
--   * the binders to use in place of @bndrs@.
--
-- If @v@ is itself bound by @bndrs@, it is shadowed under the @forall@. In
-- that case @k@ receives /no/ substitutions and the original binders, so the
-- body is left untouched.
--
-- Otherwise, to prevent the binders from capturing variables of @t@, every
-- binder is renamed to a name produced by 'evades' from the binder names and
-- @t@. @k@ then receives @(v, t)@ followed by those renamings. When the list
-- is consumed with 'doSubsts' (a right fold), the renamings are applied
-- before @(v, t)@.
commonForallCase :: (Name, Type) -> [TyVarBndr_ flag]
                 -> ([(Name, Type)] -> [TyVarBndr_ flag] -> a)
                 -> a
commonForallCase vt@(v, t) bndrs k
    -- @v@ is shadowed by the forall. Substituting it in the body would
    -- rewrite the bound variable, so hand back no substitutions at all.
  | v `elem` vars = k [] bndrs
  | otherwise     = k (vt : renamings) freshTyVarBndrs
  where
    vars = tvName <$> bndrs
    -- Replacement names for the binders (prevents capture of @t@).
    freshes = evades vars t
    freshTyVarBndrs = zipWith tyVarBndrSetName freshes bndrs
    renamings = zip vars (VarT <$> freshes)

-- | Apply a list of substitutions, one after another.
--
-- The list is processed right-to-left: the last pair is applied to the
-- argument first, and the head of the list is applied last.
doSubsts :: SubstTypeVariable a => [(Name, Type)] -> a -> a
doSubsts []           x = x
doSubsts (vt : rest) x = subst vt (doSubsts rest x)

-- | Capture-free substitution of a type for a type variable within a
-- 'Type'. @substInType (v, t) ty@ replaces occurrences of @v@ in @ty@ by @t@.
substInType :: (Name,Type) -> Type -> Type
substInType vt ty = subst vt ty

-- | Capture-free substitution of a type for a type variable within the
-- field types of a data constructor ('Con'). @substInCon (v, t) con@
-- replaces occurrences of @v@ in @con@ by @t@.
substInCon :: (Name,Type) -> Con -> Con
substInCon vt con = subst vt con