-- This code is adapted from the typed-protocols package, with a few modifications.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module TypedSession.Codec where

import Control.Exception (Exception)
import TypedSession.Core

{- |
Message encoder: turns any message of protocol @ps@ into @bytes@.

The 'encode' field is polymorphic in the sender and receiver roles and in
all three @ps@-kinded state indices of 'Msg', so a single 'Encode' value
must handle every message of the protocol.
-}
newtype Encode role' ps bytes = Encode
  { encode
      :: forall (send :: role') (recv :: role') (st :: ps) (st' :: ps) (st'' :: ps)
       . Msg role' ps st send st' recv st''
      -> bytes
  }

{- |
Message decoder: an incremental decoder ('DecodeStep') over @bytes@ that
either fails with @failure@ or yields a message of protocol @ps@ packaged
as 'AnyMsg'.
-}
newtype Decode role' ps failure bytes = Decode
  { decode :: DecodeStep bytes failure (AnyMsg role' ps)
  }

{- |
A single step of a generic incremental decoder that produces an @a@.

Library-specific incremental decoders have to be converted into this shape
before they can be run (see 'runDecoderWithChannel').
-}
data DecodeStep bytes failure a
  = -- | The decoder needs more input. The continuation is given the next
    -- chunk of input, or 'Nothing' (for example when 'runDecoderWithChannel'
    -- passes on a 'Nothing' returned by the channel's receive action).
    DecodePartial (Maybe bytes -> DecodeStep bytes failure a)
  | -- | Decoding succeeded: the decoded value and any input left unconsumed.
    DecodeDone a (Maybe bytes)
  | -- | Decoding failed with the given error.
    DecodeFail failure

{- |
A concrete failure type for codecs. It is an 'Exception' instance, so it can
be thrown and caught.
-}
data CodecFailure
  = -- | Presumably: the input ended before a complete message was decoded.
    CodecFailureOutOfInput
  | -- | Any other failure, carrying a free-form description.
    CodecFailure String
  deriving (Eq, Show)

instance Exception CodecFailure

{- |
The lowest-level transport: actions for sending and receiving raw @bytes@
in the monad @m@.
-}
data Channel m bytes = Channel
  { send :: bytes -> m ()
  -- ^ Send one chunk of bytes.
  , recv :: m (Maybe bytes)
  -- ^ Receive the next chunk of bytes. 'runDecoderWithChannel' passes a
  -- 'Nothing' result straight to the decoder; presumably 'Nothing' means
  -- the channel is closed (TODO confirm with channel implementations).
  }

{- |
Drive an incremental decoder to completion, pulling input from the channel
whenever the decoder asks for more.

The 'Maybe' argument is leftover input (e.g. trailing bytes from a previous
decode). If present, it is fed to the decoder first. After that, input comes
only from the channel's 'recv'.

Returns @Left failure@ if the decoder fails. Otherwise it returns
@Right (value, trailing)@, where @trailing@ is whatever input the decoder
reported as unconsumed.

Note: a 'Nothing' from 'recv' is handed to the decoder's continuation
unchanged. If the decoder answers with yet another 'DecodePartial', 'recv' is
called again. So a decoder that keeps requesting input after seeing 'Nothing'
keeps this loop running for as long as the channel keeps returning 'Nothing'.
-}
runDecoderWithChannel
  :: (Monad m)
  => Channel m bytes
  -> Maybe bytes
  -> DecodeStep bytes failure a
  -> m (Either failure (a, Maybe bytes))
runDecoderWithChannel Channel{recv} = go
 where
  go _ (DecodeDone x trailing) = pure (Right (x, trailing))
  go _ (DecodeFail failure) = pure (Left failure)
  -- No buffered input left: wait on the channel and feed its result on.
  go Nothing (DecodePartial k) = recv >>= go Nothing . k
  -- Buffered input: consume it first; from then on input comes from 'recv'.
  go (Just trailing) (DecodePartial k) = go Nothing (k (Just trailing))