Skip to content

zongwu's blog

WriterT 相关源码分析

看完ReaderT源码,接着来分析WriterT的源码。

示例

先看一个超级简单的例子:

import Control.Monad
import Control.Monad.Trans.Writer

data LogEntry = LogEntry { msg :: String } deriving (Eq, Show)

calc :: Writer [LogEntry] Integer
calc = do
    tell [LogEntry "Start"]
    let x = sum [1 .. 10000000]
    tell [LogEntry (show x)]
    tell [LogEntry "done"]
    return x
    
test = execWriter calc
test2 = runWriter calc

test执行结果:

[LogEntry {msg = "Start"},LogEntry {msg = "50000005000000"},LogEntry {msg = "done"}]

test2执行结果:

(50000005000000,[LogEntry {msg = "Start"},LogEntry {msg = "50000005000000"},LogEntry {msg = "done"}])

我们定义了一个LogEntry类型用来简单模拟日志功能,在calc方法里,多次调用tell操作,就可以不断地增加LogEntry数据 。调用test,test2就可以得到累积的结果。test2得到一个(a,w)类型的值,test丢弃了a,只保留w

Writer是如何实现tell这种操作的?我们从源码层面深入了解一下。

先从Writer开始。

Writer

Control.Monad.Trans.Writer.Strict 模块Control.Monad.Trans.Writer.Lazy 模块同时都定义了WritertWriterT,两个模块绝大部分的代码都是相同的。主要区别在于,如果其中使用模式匹配的求值模式。例如:

(a,b) =  ...

(a,b)是否立即求值(Strict 模式),还是等a,b真正使用的那一刻再求值(如果不使用就永远不求值,Lazy模式)。这可能会影响程序的性能,此外如果b是一个无限队列,Strict 模块就不能正确处理了。

Control.Monad.Trans.Writer模块使用的是Lazy模块,所以我们只看Control.Monad.Trans.Writer.Lazy 模块下的相关定义。

-- ---------------------------------------------------------------------------
-- | A writer monad parameterized by the type @w@ of output to accumulate.
--
-- The 'return' function produces the output 'mempty', while @>>=@
-- combines the outputs of the subcomputations using 'mappend'.
type Writer w = WriterT w Identity

熟悉的套路,Writer的定义之于WriterT,就像Reader之于ReaderT

Writer只是一个别名,重点看WriterT,在同一个源文件里。

WriterT

-- | A writer monad parameterized by:
--
--   * @w@ - the output to accumulate.
--
--   * @m@ - The inner monad.
--
-- The 'return' function produces the output 'mempty', while @>>=@
-- combines the outputs of the subcomputations using 'mappend'.
newtype WriterT w m a = WriterT { runWriterT :: m (a, w) }
    deriving (Generic)
    
instance (Functor m) => Functor (WriterT w m) where
    fmap f = mapWriterT $ fmap $ \ ~(a, w) -> (f a, w)

instance (Foldable f) => Foldable (WriterT w f) where
    ...

instance (Traversable f) => Traversable (WriterT w f) where
		...

instance (Monoid w, Applicative m) => Applicative (WriterT w m) where
    pure a  = WriterT $ pure (a, mempty)
    f <*> v = WriterT $ liftA2 k (runWriterT f) (runWriterT v)
      where k ~(a, w) ~(b, w') = (a b, w `mappend` w')
  
instance (Monoid w, Alternative m) => Alternative (WriterT w m) where
    empty   = WriterT empty
    m <|> n = WriterT $ runWriterT m <|> runWriterT n

instance (Monoid w, Monad m) => Monad (WriterT w m) where
    
    return a = writer (a, mempty)

    m >>= k  = WriterT $ do
        ~(a, w)  <- runWriterT m
        ~(b, w') <- runWriterT (k a)
        return (b, w `mappend` w')
     
#if !(MIN_VERSION_base(4,13,0))
    fail msg = WriterT $ fail msg
#endif

WriterT的定义只是包装了一下m (a,w)。定义前面的注释讲, m是一个monadw是用于累积的输出。

WriterT实现了很多typeclass 比如Functor, Foldable, Traversable, Applicative, Alternative, Monad

我们重点看Monad (WriterT w m)的实现。自然就是return以及>>=操作的实现了。

fail msg = WriterT $ fail msg 比较直观不解释。

return 方法

 return a = writer (a, mempty)

调用的writer的定义:

-- | Construct a writer computation from a (result, output) pair.
-- (The inverse of 'runWriter'.)
writer :: (Monad m) => (a, w) -> WriterT w m a
writer = WriterT . return

于是WriterTreturn方法:

return a = writer (a, mempty)
				 = WriterT . return (a,mempty) 
				 = WriterT . m (a,mempty)  -- Writer 的 return a = m a

也即是将 a 包装成了 WriterT mempty m aw这个累积量为mempty

>>= 方法

继续看Monad (WriterT w m)>>=方法实现:

    m >>= k  = WriterT $ do
        ~(a, w)  <- runWriterT m
        ~(b, w') <- runWriterT (k a)
        return (b, w `mappend` w')

~ 关键字是Lazy pattern bindings,当左边匹配值在后续被使用时,才对匹配的绑定值进行具体求值。

注意m >>= k 里的m不是WriterT w m里的m,而是表示WriterT w m a。(因为Monad (WriterT w m))。

这里的逻辑比较容易理解:

do里面执行runWriterT m 得到(a,w)

再次执行runWriterT (k,a)得到(b,w')

然后将resultb以及拼接的 w mappend w'作为元组, returnm (b, w mappend w')

再使用WriterT生成最终值(刚好符合定义的类型WriterT { runWriterT :: m (a, w) })。

从这里的操作过程可以更清楚w为什么被称为the output to accumulate

memptymappend都是monoid中定义的操作:

class Semigroup m => Monoid m where
  mempty :: m

  -- defining mappend is unnecessary, it copies from Semigroup
  mappend :: m -> m -> m
  mappend = (<>)

  -- defining mconcat is optional, since it has the following default:
  mconcat :: [m] -> m
  mconcat = foldr mappend mempty

<>来自于Semigroupmonoid遵循的定律:

-- 单位元法则
-- Identity laws
x <> mempty = x
mempty <> x = x

-- 结合律
-- Associativity laws
(x <> y) <> z = x <> (y <> z)

MonadWriter

在看WriterT的其他操作之前,先看一看MonadWriter这个重要的typeclassMonadWriter抽象定义了Monad Writer的行为。WriterTMonadWriter众多instance其中的一个。

Control.Monad.Writer.Class 文件:

class (Monoid w, Monad m) => MonadWriter w m | m -> w where
    -- | @'writer' (a,w)@ embeds a simple writer action.
    writer :: (a,w) -> m a
    writer ~(a, w) = do
      tell w
      return a

    -- | @'tell' w@ is an action that produces the output @w@.
    tell   :: w -> m ()
    tell w = writer ((),w)

    -- | @'listen' m@ is an action that executes the action @m@ and adds
    -- its output to the value of the computation.
    listen :: m a -> m (a, w)
    -- | @'pass' m@ is an action that executes the action @m@, which
    -- returns a value and a function, and returns the value, applying
    -- the function to the output.
    pass   :: m (a, w -> w) -> m a
    
    
instance (Monoid w, Monad m) => MonadWriter w (Lazy.WriterT w m) where
    writer = Lazy.writer
    tell   = Lazy.tell
    listen = Lazy.listen
    pass   = Lazy.pass

instance (Monoid w, Monad m) => MonadWriter w (Strict.WriterT w m) where
    writer = Strict.writer
    tell   = Strict.tell
    listen = Strict.listen
    pass   = Strict.pass 

MonadWriter w m定义了四个方法。writertell相互递归定义(Mutual recursion),writer操作接受一个(a,w)类型参数,tell接受一个w类型的参数(w是一个 monoid)。两者都产生一个Writer action

实际上,如果一个值的类型是 Monad m => m a,这个值就是一个actionmMonad的实例,a是执行(do)action的结果的类型。

而函数(function),是将输入参数与输出结果关联起来而构建的特定关系。Haskell是纯函数式编程,这意味着,在runtime阶段,函数表现的还是像编译阶段那样,只是一种(关联输入参数与输出的)结构,此外不会do其他任何事情。

action,不需要输入参数,并且产生结果。在runtime阶段,执行(例如 do { m a ; ... })之后,会对环境产生副作用。

为了体会一下两者的区别,看一下:

putStrLn::String -> IO ()

是一个函数。

putStrLn "hello world" :: IO ()

是一个 action,执行这个 action的话,会在屏幕输出"hello world",这个action的结果类型是 ()

回来继续看,Lazy.WriterT w m就是上面我们看过的WriterT w m

instance (Monoid w, Monad m) => MonadWriter w (Lazy.WriterT w m) where
    writer = Lazy.writer
    tell   = Lazy.tell
    listen = Lazy.listen
    pass   = Lazy.pass

接下来分析WriterT 实现MonadWriter的四个操作(为什么称为操作,因为这类方法经常以monad为参数,或者经常含有do,会执行action)。

Lazy.writer就是上面分析return时分析过的writer操作,不再分析。

Lazy.tellWriterT中的定义:

-- | @'tell' w@ is an action that produces the output @w@.
tell :: (Monad m) => w -> WriterT w m ()
tell w = writer ((), w)

我们知道writer操作就是将(a,w)构造成WriterT w m a 类型的,于是有tell

tell w  = writer ((),w)
	= WriterT . return ((), w)
	= WriterT m ((),w)
	= WriterT w m ()

Lazy.listenWriterT中的定义,步骤很直观,无需解释:

-- | @'listen' m@ is an action that executes the action @m@ and adds its
-- output to the value of the computation.
--
-- * @'runWriterT' ('listen' m) = 'liftM' (\\ (a, w) -> ((a, w), w)) ('runWriterT' m)@
listen :: (Monad m) => WriterT w m a -> WriterT w m (a, w)
listen m = WriterT $ do
    ~(a, w) <- runWriterT m
    return ((a, w), w)    

Lazy.passWriterT中的定义:

-- | @'pass' m@ is an action that executes the action @m@, which returns
-- a value and a function, and returns the value, applying the function
-- to the output.
--
-- * @'runWriterT' ('pass' m) = 'liftM' (\\ ((a, f), w) -> (a, f w)) ('runWriterT' m)@
pass :: (Monad m) => WriterT w m (a, w -> w) -> WriterT w m a
pass m = WriterT $ do
    ~((a, f), w) <- runWriterT m
    return (a, f w)

常用操作(operation)

-- | Extract the output from a writer computation.
--
-- * @'execWriterT' m = 'liftM' 'snd' ('runWriterT' m)@
execWriterT :: (Monad m) => WriterT w m a -> m w
execWriterT m = do
    ~(_, w) <- runWriterT m
    return w
    
-- | Map both the return value and output of a computation using
-- the given function.
--
-- * @'runWriterT' ('mapWriterT' f m) = f ('runWriterT' m)@
mapWriterT :: (m (a, w) -> n (b, w')) -> WriterT w m a -> WriterT w' n b
mapWriterT f m = WriterT $ f (runWriterT m)


-- | @'listens' f m@ is an action that executes the action @m@ and adds
-- the result of applying @f@ to the output to the value of the computation.
--
-- * @'listens' f m = 'liftM' (id *** f) ('listen' m)@
--
-- * @'runWriterT' ('listens' f m) = 'liftM' (\\ (a, w) -> ((a, f w), w)) ('runWriterT' m)@
listens :: (Monad m) => (w -> b) -> WriterT w m a -> WriterT w m (a, b)
listens f m = WriterT $ do
    ~(a, w) <- runWriterT m
    return ((a, f w), w)
    

实现都比较直观。

另外也看一看Writer的其他常用操作:

-- | Unwrap a writer computation as a (result, output) pair.
-- (The inverse of 'writer'.)
runWriter :: Writer w a -> (a, w)
runWriter = runIdentity . runWriterT

-- | Extract the output from a writer computation.
--
-- * @'execWriter' m = 'snd' ('runWriter' m)@
execWriter :: Writer w a -> w
execWriter m = snd (runWriter m)

-- | Map both the return value and output of a computation using
-- the given function.
--
-- * @'runWriter' ('mapWriter' f m) = f ('runWriter' m)@
mapWriter :: ((a, w) -> (b, w')) -> Writer w a -> Writer w' b
mapWriter f = mapWriterT (Identity . f . runIdentity)

runWriterwriter是互逆操作。runWriter的实现的详细解析过程如下:

runWriter (Writer w a)  = runIndentity . runWriterT (Writer w a)
			= runIndentity . runWriterT (WriterT w Indentity a)
			-- runIndentity . m (a,w)
			= runIndentity . Indentity (a,w) 
			= (a,w)

其他两个方法的实现也是类似推导。

再看示例

看完了源码,终于可以彻底分析一开始的示例了,对于关键操作:

calc :: Writer [LogEntry] Integer
calc = do
    tell [LogEntry "Start"]
    let x = sum [1 .. 10000000]
    tell [LogEntry (show x)]
    tell [LogEntry "done"]
    return x

理解的第一个关键点是do操作,在do里面我们调用了三次tell,表面上看,没有使用 a <- m a这种方式调用tell,而是丢弃了tell的计算结果,那如何进行积累操作呢?让我们还原do语法糖:

calc :: Writer [LogEntry] Integer
-- let 的声明转换有一点不精确,但是这个例子无影响
calc = let x = sum [1 .. 10000000] in
		tell [LogEntry "Start"] >>  tell [LogEntry (show x)] >> tell [LogEntry "done"]
		return x

第二个关键点是>> 操作符的理解。从上面的源码我们知道,WriterT 以及Writer都没有重新定义自己的>>,那么这个>>就是Monad原始定义的:

(>>) :: m a -> m b -> m b
a >> f = a >>= \_ -> f

我们先只看tell [LogEntry "Start"] >> tell [LogEntry (show x)] ,对show x求值为 "50000005000000" ,同时注意应用>>=WriterT中的定义:

tell [LogEntry "Start"] >> tell [LogEntry "50000005000000"] 
= tell [LogEntry "Start"] >>=  \_ -> tell [LogEntry "50000005000000"]
-- 用writer替换
= writer((),[LogEntry "Start"]) >>= \_ -> writer((),[LogEntry "50000005000000"])
-- 接下来就是套用 >>= 的定义了
= WriterT $  do
  ~(a,w)  <- runWriterT writer((),[LogEntry "Start"])
  ~(b,w') <- runWriterT (\_ -> writer((),[LogEntry "50000005000000"]) a)
  return (b, w `mappend` w')
-- 操作符"<-"右边的可以计算, 另外 Writer 的 m 是 Indentity
= WriterT $  do
  ~(a,w)  <- Identity ((),[LogEntry "Start"])
  -- (\_ -> m) a = m ,这一步是应用函数到参数
  ~(b,w') <- Identity ((),[LogEntry "50000005000000"])
  return (b, w `mappend` w')
= Writer $ do
  -- a = (),b = (),w = [LogEntry "Start"], w'=[LogEntry "50000005000000"]
  return ((),[LogEntry "Start"] `mappend` [LogEntry "50000005000000"])
= WriterT [LogEntry "Start",LogEntry "50000005000000"] Identity ()
= Writer [LogEntry "Start",LogEntry "50000005000000"] ()

再加上tell [LogEntry "done"] 就会是:

Writer [LogEntry "Start",LogEntry "50000005000000",LogEntry "done"] ()

至于execWriterrunWriter上面源码已经很明显,不需要再讲解。

参考

  1. https://hackage.haskell.org/package/mtl-2.2.2/docs/src/Control.Monad.Writer.Class.html#MonadWriter
  2. https://hackage.haskell.org/package/transformers-0.6.0.2/docs/src/Control.Monad.Trans.Writer.Lazy.html
  3. https://hackage.haskell.org/package/transformers-0.6.0.2/docs/src/Control.Monad.Trans.Writer.Strict.html
  4. https://kseo.github.io/posts/2017-01-21-writer-monad.html
  5. http://dev.stephendiehl.com/fun/basics.html
  6. https://mmhaskell.com/monads/reader-writer
  7. https://blog.ssanj.net/posts/2018-01-12-stacking-the-readert-writert-monad-transformer-stack-in-haskell.html