{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- SPDX-License-Identifier: BSD-3-Clause
--
-- Description: Naive (slow) substitution-based implementation of
-- unification.  Uses a simple but expensive-to-maintain invariant on
-- substitutions, and returns a substitution from unification which
-- must then be composed with the substitution being tracked.
--
-- Not used in Swarm, and also unmaintained
-- (e.g. "Swarm.Effect.Unify.Fast" now supports expanding type
-- aliases + recursive types; this module does not). It's still here just for
-- testing/comparison.
module Swarm.Effect.Unify.Naive where

import Control.Algebra
import Control.Applicative (Alternative)
import Control.Carrier.State.Strict (StateC, evalState)
import Control.Carrier.Throw.Either (ThrowC, runThrow)
import Control.Category ((>>>))
import Control.Effect.State (get, gets, modify)
import Control.Effect.Throw (Throw, throwError)
import Control.Monad (zipWithM)
import Control.Monad.Free
import Control.Monad.Trans (MonadIO)
import Data.Function (on)
import Data.Map ((!?))
import Data.Map qualified as M
import Data.Map.Merge.Lazy qualified as M
import Data.Maybe (fromMaybe)
import Data.Set qualified as S
import Swarm.Effect.Unify
import Swarm.Effect.Unify.Common
import Swarm.Language.Types hiding (Type)

------------------------------------------------------------
-- Substitutions

-- | Types that a substitution can be applied to.  An instance
--   @Substitutes n b a@ says that a @'Subst' n b@ (a finite map from
--   names of type @n@ to values of type @b@) can be applied to a value
--   of type @a@.  Every free name of type @n@ occurring inside the @a@
--   that the substitution binds is replaced by its @b@ value, giving a
--   new @a@.
class Substitutes n b a where
  -- | Apply a substitution to a value.
  subst :: Subst n b -> a -> a

-- | Substitution on terms represented as the free monad over a
--   structure functor @f@, with variables of type @n@ at the leaves.
--   Every leaf variable bound by the substitution is replaced by its
--   image, and unbound variables are left alone.  Images are spliced in
--   as-is; they are not themselves substituted again.
instance (Show n, Ord n, Functor f) => Substitutes n (Free f n) (Free f n) where
  subst s term = term >>= replaceVar
   where
    bindings = getSubst s
    -- Look the variable up; keep it as a leaf if it has no binding.
    replaceVar v = fromMaybe (Pure v) (bindings !? v)

-- | Compose two substitutions.  Applying @s1 \@\@ s2@ to a term is the
--   same as applying @s2@ first and then @s1@.  Viewed as functions on
--   terms, composition of substitutions is therefore exactly function
--   composition.
--
--   Composition is associative and has 'idS' as its identity.
(@@) :: (Ord n, Substitutes n a a) => Subst n a -> Subst n a -> Subst n a
Subst outer @@ Subst inner = Subst (M.union updatedInner outer)
 where
  -- Push the outer substitution into every image of the inner one.
  -- 'M.union' is left-biased, so inner bindings win on shared keys.
  updatedInner = M.map (subst (Subst outer)) inner

-- | Compose every substitution in a container.  For example,
--   @compose [s1, s2, s3] = s1 \@\@ s2 \@\@ s3@, so the last one in the
--   container is applied first.  An empty container gives 'idS'.
compose :: (Ord n, Substitutes n a a, Foldable t) => t (Subst n a) -> Subst n a
compose substs = foldr (@@) idS substs

------------------------------------------------------------
-- Carrier type

-- Note: this carrier type and the runUnification function are
-- identical between this module and Swarm.Effect.Unify.Fast, but it
-- seemed best to duplicate it, so we can modify the carriers
-- independently in the future if we want.

-- | Carrier for the 'Unification' effect.  It stacks three layers:
--   the substitution accumulated so far, a counter used to mint fresh
--   unification variables, and the ability to throw a
--   'UnificationError'.
newtype UnificationC m a = UnificationC
  { unUnificationC ::
      StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m))
        a
  }
  deriving newtype
    ( Functor
    , Applicative
    , Alternative
    , Monad
    , MonadIO
    )

-- | Counter used to generate fresh unification variables.  It holds
--   the number from which the next fresh 'IntVar' is built, and
--   'succ' advances it.
newtype FreshVarCounter = FreshVarCounter
  { getFreshVarCounter :: Int
  }
  deriving newtype (Eq, Ord, Enum)

-- | Interpret a 'Unification' computation using the 'UnificationC'
--   carrier.  The run starts from the identity substitution and a
--   fresh-variable counter of 0.  A 'UnificationError' thrown to the
--   carrier's 'ThrowC' layer ends the run and is returned as 'Left'.
runUnification :: Algebra sig m => UnificationC m a -> m (Either UnificationError a)
runUnification =
  unUnificationC >>> dischargeSubst >>> dischargeCounter >>> runThrow
 where
  -- Begin with no bindings at all...
  dischargeSubst = evalState idS
  -- ...and number fresh variables starting from 0.
  dischargeCounter = evalState (FreshVarCounter 0)

------------------------------------------------------------
-- Unification

-- | Naive implementation of the 'Unification' effect in terms of the
--   'UnificationC' carrier.
--
--   We maintain an invariant on the current @Subst@ that map keys
--   never show up in any of the values.  For example, we could have
--   @{x -> a+5, y -> 5}@ but not @{x -> a+y, y -> 5}@.
--
--   A failing 'Unify' is reported to the caller as a 'Left' result,
--   and the current substitution is left unchanged.  The failure does
--   not abort the whole unification computation.
instance Algebra sig m => Algebra (Unification :+: sig) (UnificationC m) where
  alg hdl sig ctx = UnificationC $ case sig of
    L (Unify t1 t2) -> do
      s1 <- get @(Subst IntVar UType)
      -- Bring both types up to date with the current substitution
      -- before unifying them.
      let t1' = subst s1 t1
          t2' = subst s1 t2
      -- Catch unification failures here so they become the 'Left'
      -- half of the effect's result, instead of escaping to the
      -- carrier's outer 'ThrowC' and ending the entire run.
      result <- runThrow (unify t1' t2')
      case result of
        Left err -> return $ Left err <$ ctx
        Right s2 -> do
          -- Fold the new mgu into the tracked substitution.
          modify (s2 @@)
          return $ Right (subst s2 t1') <$ ctx
    L (ApplyBindings t) -> do
      s <- get @(Subst IntVar UType)
      return $ subst s t <$ ctx
    L FreshIntVar -> do
      v <- IntVar <$> gets getFreshVarCounter
      modify @FreshVarCounter succ
      return $ v <$ ctx
    L (FreeUVars t) -> do
      s <- get @(Subst IntVar UType)
      return $ fuvs (subst s t) <$ ctx
    R other -> alg (unUnificationC . hdl) (R (R (R other))) ctx

-- | Compute the most general unifier (mgu) of two types, i.e. the
--   smallest substitution that makes them equal.
--
--   A variable is bound to the other side after an occurs check, which
--   throws 'Infinite' if the variable appears inside that type.  Two
--   non-variable types are handed off to 'unifyF'.
unify ::
  Has (Throw UnificationError) sig m =>
  UType ->
  UType ->
  m (Subst IntVar UType)
unify ty1 ty2 = case (ty1, ty2) of
  -- A variable trivially unifies with itself.
  (Pure x, Pure y) | x == y -> return idS
  (Pure x, t) -> bindVar x t
  (t, Pure x) -> bindVar x t
  (Free f1, Free f2) -> unifyF f1 f2
 where
  -- Bind a variable to a type, refusing to build an infinite type.
  bindVar v t
    | v `S.member` fuvs t = throwError $ Infinite v t
    | otherwise = return $ v |-> t

-- | Unify two non-variable terms and return an mgu, i.e. the smallest
--   substitution which makes them equal.
--
--   Compound types are unified component by component using
--   'unifyPairs'.  This threads the substitution found so far through
--   later components, so a variable shared between components is
--   bound consistently.
unifyF ::
  Has (Throw UnificationError) sig m =>
  TypeF UType ->
  TypeF UType ->
  m (Subst IntVar UType)
unifyF t1 t2 = case (t1, t2) of
  (TyConF c1 ts1, TyConF c2 ts2)
    -- Same constructor applied to the same number of arguments;
    -- 'zip' alone would silently drop surplus arguments.
    | c1 == c2 && length ts1 == length ts2 -> unifyPairs (zip ts1 ts2)
    | otherwise -> unifyErr
  (TyConF {}, _) -> unifyErr
  (TyVarF _ v1, TyVarF _ v2)
    | v1 == v2 -> return idS
    | otherwise -> unifyErr
  (TyVarF {}, _) -> unifyErr
  (TyRcdF m1, TyRcdF m2)
    -- Records must have exactly the same field names.  Pair up the
    -- corresponding field types (in key order) and unify them all.
    | ((==) `on` M.keysSet) m1 m2 ->
        unifyPairs . M.elems $
          M.merge M.dropMissing M.dropMissing (M.zipWithMatched (const (,))) m1 m2
    | otherwise -> unifyErr
  (TyRcdF {}, _) -> unifyErr
  -- Don't support any extra features (e.g. recursive types), so just
  -- add a catch-all failure case
  (_, _) -> unifyErr
 where
  unifyErr = throwError $ UnifyErr t1 t2

-- | Simultaneously unify a list of pairs of types.  The result is a
--   single substitution that makes both sides of every pair equal.
--
--   The first pair is unified, and its mgu is applied to all remaining
--   pairs before they are unified in turn.  Unifying the pairs
--   independently and composing the results afterwards is unsound: for
--   @x * x@ against @int * bool@ it would produce @{x -> bool}@ instead
--   of failing.
unifyPairs ::
  Has (Throw UnificationError) sig m =>
  [(UType, UType)] ->
  m (Subst IntVar UType)
unifyPairs [] = return idS
unifyPairs ((a, b) : rest) = do
  s <- unify a b
  s' <- unifyPairs [(subst s x, subst s y) | (x, y) <- rest]
  -- Apply @s@ first, then @s'@.
  return (s' @@ s)