An Analysis of the WriterT Source Code
Having gone through the source of ReaderT, we now turn to the source of WriterT.
Example
Let's start with a very simple example:
import Control.Monad
import Control.Monad.Trans.Writer
data LogEntry = LogEntry { msg :: String } deriving (Eq, Show)
calc :: Writer [LogEntry] Integer
calc = do
  tell [LogEntry "Start"]
  let x = sum [1 .. 10000000]
  tell [LogEntry (show x)]
  tell [LogEntry "done"]
  return x
test = execWriter calc
test2 = runWriter calc
test
Result:
[LogEntry {msg = "Start"},LogEntry {msg = "50000005000000"},LogEntry {msg = "done"}]
test2
Result:
(50000005000000,[LogEntry {msg = "Start"},LogEntry {msg = "50000005000000"},LogEntry {msg = "done"}])
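We define a LogEntry type to simulate a simple logging facility. Inside calc, each call to tell appends more LogEntry values to the log. Evaluating test and test2 then yields the accumulated result: test2 produces a value of type (a, w), while test discards the a and keeps only the w.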
How does Writer implement an operation like tell? Let's look at the source to find out.
We start with Writer.
Writer
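Both the Control.Monad.Trans.Writer.Strict module and the Control.Monad.Trans.Writer.Lazy module define Writer and WriterT, and most of the code in the two modules is identical. The main difference lies in how pattern matches on the (result, output) pair are evaluated. For example: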
(a,b) = ...
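Is the pair (a,b) matched immediately (Strict), or only at the moment a or b is actually used (Lazy; if neither is ever used, the match never happens)? This can affect the performance of a program. In addition, when b is an infinite list built up by a computation that never finishes, the Strict module cannot produce it incrementally, whereas the Lazy module can.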
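As a small illustration of that last point, the sketch below builds an endless log with the Lazy module and still reads a finite prefix of it. The name ones is made up for this example. With Control.Monad.Trans.Writer.Strict imported instead, the same expression is expected to loop forever, because every strict pair match forces the rest of the computation first.

```haskell
import Control.Monad (forever)
import Control.Monad.Trans.Writer.Lazy (Writer, execWriter, tell)

-- A computation that never finishes but keeps emitting output.
-- (`ones` is a hypothetical name used only in this sketch.)
ones :: Writer [Int] ()
ones = forever (tell [1])

main :: IO ()
main = print (take 3 (execWriter ones))  -- prints [1,1,1]
```

Note that neither module is strict in the accumulated output itself; the difference is only in how the pair is matched.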
Control.Monad.Trans.Writer re-exports the Lazy module, so we only look at the definitions in Control.Monad.Trans.Writer.Lazy.
-- ---------------------------------------------------------------------------
-- | A writer monad parameterized by the type @w@ of output to accumulate.
--
-- The 'return' function produces the output 'mempty', while @>>=@
-- combines the outputs of the subcomputations using 'mappend'.
type Writer w = WriterT w Identity
A familiar pattern: Writer is to WriterT what Reader is to ReaderT.
Writer is only a type synonym; the real work is in WriterT, which is defined in the same source file.
WriterT
-- | A writer monad parameterized by:
--
-- * @w@ - the output to accumulate.
--
-- * @m@ - The inner monad.
--
-- The 'return' function produces the output 'mempty', while @>>=@
-- combines the outputs of the subcomputations using 'mappend'.
newtype WriterT w m a = WriterT { runWriterT :: m (a, w) }
    deriving (Generic)
instance (Functor m) => Functor (WriterT w m) where
    fmap f = mapWriterT $ fmap $ \ ~(a, w) -> (f a, w)
instance (Foldable f) => Foldable (WriterT w f) where
    ...
instance (Traversable f) => Traversable (WriterT w f) where
    ...
instance (Monoid w, Applicative m) => Applicative (WriterT w m) where
    pure a  = WriterT $ pure (a, mempty)
    f <*> v = WriterT $ liftA2 k (runWriterT f) (runWriterT v)
      where k ~(a, w) ~(b, w') = (a b, w `mappend` w')
instance (Monoid w, Alternative m) => Alternative (WriterT w m) where
    empty   = WriterT empty
    m <|> n = WriterT $ runWriterT m <|> runWriterT n
instance (Monoid w, Monad m) => Monad (WriterT w m) where
    return a = writer (a, mempty)
    m >>= k  = WriterT $ do
        ~(a, w)  <- runWriterT m
        ~(b, w') <- runWriterT (k a)
        return (b, w `mappend` w')
#if !(MIN_VERSION_base(4,13,0))
    fail msg = WriterT $ fail msg
#endif
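The definition of WriterT is just a wrapper around m (a, w). According to the comment above the definition, m is the inner monad and w is the output to accumulate.
WriterT is an instance of many type classes, including Functor, Foldable, Traversable, Applicative, Alternative and Monad.
We focus on the Monad (WriterT w m) instance, which means the implementations of return and >>=.
The line fail msg = WriterT $ fail msg simply delegates to the fail of the inner monad and needs no further explanation.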
The return method
return a = writer (a, mempty)
It calls writer, which is defined as:
-- | Construct a writer computation from a (result, output) pair.
-- (The inverse of 'runWriter'.)
writer :: (Monad m) => (a, w) -> WriterT w m a
writer = WriterT . return
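So the return method of WriterT expands as follows:
return a = writer (a, mempty)
         = (WriterT . return) (a, mempty)
         = WriterT (return (a, mempty))  -- the inner return is that of the inner monad m
In other words, a is wrapped into a value of type WriterT w m a whose accumulated output w is mempty.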
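A quick check of this behaviour, as a minimal sketch: pureValue is a name invented for the example, and the output type [String] is an arbitrary choice of monoid.

```haskell
import Control.Monad.Trans.Writer (Writer, runWriter)

-- `return` adds nothing to the log: the output is `mempty` of the
-- chosen monoid, which is the empty list here.
-- (`pureValue` is a hypothetical name for this sketch.)
pureValue :: Writer [String] Int
pureValue = return 42

main :: IO ()
main = print (runWriter pureValue)  -- prints (42,[])
```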
The >>= method
Next, the implementation of >>= in Monad (WriterT w m):
m >>= k = WriterT $ do
    ~(a, w)  <- runWriterT m
    ~(b, w') <- runWriterT (k a)
    return (b, w `mappend` w')
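~ introduces a lazy (irrefutable) pattern: the match itself always succeeds, and the matched value is only evaluated when one of the variables bound by the pattern is actually used.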
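The effect of ~ can be seen in isolation, without any Writer at all. The following is a minimal sketch; strictConst and lazyConst are hypothetical helper names, not part of any library.

```haskell
-- Strict match: the argument must be evaluated to a pair before the
-- body runs (the components themselves are not forced).
strictConst :: (a, b) -> Int
strictConst (_, _) = 1

-- Lazy match: the pattern always succeeds; the argument would only be
-- inspected if a variable bound by the pattern were used.
lazyConst :: (a, b) -> Int
lazyConst ~(_, _) = 1

main :: IO ()
main = do
  print (lazyConst undefined)                 -- prints 1
  print (strictConst (undefined, undefined))  -- prints 1
  -- print (strictConst undefined) would throw an exception instead
```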
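Note that the m in m >>= k is not the m of WriterT w m; it stands for a value of type WriterT w m a (because the instance being defined is Monad (WriterT w m)).
The logic is easy to follow:
Inside the do block, runWriterT m is run to obtain (a, w).
Then runWriterT (k a) is run to obtain (b, w').
Finally, the result b and the combined output w `mappend` w' are paired, returned as m (b, w `mappend` w'), and wrapped with WriterT to produce the final value (which matches the declared type WriterT { runWriterT :: m (a, w) }).
This sequence of steps makes it clear why w is called the output to accumulate.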
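To see the accumulation without the transformer machinery, here is a stripped-down writer over plain pairs. It is a sketch written for this article, not the library code: W, runW, tellW and demo are all made-up names, and the inner monad is dropped entirely.

```haskell
-- A simplified writer: just a (result, output) pair, no inner monad.
newtype W w a = W { runW :: (a, w) }

instance Functor (W w) where
  fmap f m = let (a, w) = runW m in W (f a, w)

instance Monoid w => Applicative (W w) where
  pure a = W (a, mempty)
  mf <*> ma =
    let (f, w)  = runW mf
        (a, w') = runW ma
    in W (f a, w <> w')

instance Monoid w => Monad (W w) where
  -- Same shape as the WriterT bind: run m, feed its result to k,
  -- and combine the two outputs (let bindings are lazy by default).
  m >>= k =
    let (a, w)  = runW m
        (b, w') = runW (k a)
    in W (b, w <> w')

tellW :: w -> W w ()
tellW w = W ((), w)

demo :: W [String] Int
demo = tellW ["one"] >>= \_ -> tellW ["two"] >>= \_ -> pure 3

main :: IO ()
main = print (runW demo)  -- prints (3,["one","two"])
```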
mempty and mappend are both operations defined by Monoid:
class Semigroup m => Monoid m where
    mempty :: m
    -- defining mappend is unnecessary, it copies from Semigroup
    mappend :: m -> m -> m
    mappend = (<>)
    -- defining mconcat is optional, since it has the following default:
    mconcat :: [m] -> m
    mconcat = foldr mappend mempty
<> comes from Semigroup. A Monoid must obey the following laws:
-- Identity laws
x <> mempty = x
mempty <> x = x
-- Associativity law
(x <> y) <> z = x <> (y <> z)
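Because WriterT only relies on mempty and mappend, the choice of monoid decides what "accumulate" means. The sketch below uses two standard monoids; listLog and counter are names invented for the example.

```haskell
import Control.Monad.Trans.Writer (Writer, execWriter, tell)
import Data.Monoid (Sum (..))

-- With lists, `mappend` concatenates the pieces of output.
listLog :: Writer [String] ()
listLog = tell ["a"] >> tell ["b"]

-- With `Sum`, `mappend` adds them up instead.
counter :: Writer (Sum Int) ()
counter = tell (Sum 1) >> tell (Sum 2)

main :: IO ()
main = do
  print (execWriter listLog)           -- prints ["a","b"]
  print (getSum (execWriter counter))  -- prints 3
```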
MonadWriter
Before looking at the other WriterT operations, let us look at the important type class MonadWriter. MonadWriter abstracts the behaviour of a writer monad, and WriterT is just one of its many instances.
In the module Control.Monad.Writer.Class:
class (Monoid w, Monad m) => MonadWriter w m | m -> w where
    -- | @'writer' (a,w)@ embeds a simple writer action.
    writer :: (a,w) -> m a
    writer ~(a, w) = do
      tell w
      return a
    -- | @'tell' w@ is an action that produces the output @w@.
    tell   :: w -> m ()
    tell w = writer ((),w)
    -- | @'listen' m@ is an action that executes the action @m@ and adds
    -- its output to the value of the computation.
    listen :: m a -> m (a, w)
    -- | @'pass' m@ is an action that executes the action @m@, which
    -- returns a value and a function, and returns the value, applying
    -- the function to the output.
    pass   :: m (a, w -> w) -> m a
instance (Monoid w, Monad m) => MonadWriter w (Lazy.WriterT w m) where
    writer = Lazy.writer
    tell   = Lazy.tell
    listen = Lazy.listen
    pass   = Lazy.pass
instance (Monoid w, Monad m) => MonadWriter w (Strict.WriterT w m) where
    writer = Strict.writer
    tell   = Strict.tell
    listen = Strict.listen
    pass   = Strict.pass
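MonadWriter w m declares four methods. The default definitions of writer and tell are mutually recursive, so an instance has to define at least one of them. writer takes an argument of type (a, w), and tell takes an argument of type w (where w is a Monoid). Both produce a writer action.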
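One practical consequence of the class is that code can be written against MonadWriter rather than against a concrete WriterT. The following sketch assumes the mtl package is available and uses its Control.Monad.Writer module; logged is a hypothetical function name.

```haskell
{-# LANGUAGE FlexibleContexts #-}
import Control.Monad.Writer (MonadWriter, runWriter, tell, writer)

-- Works in any monad that can accumulate a [String] log.
-- (`logged` is a made-up name for this sketch.)
logged :: MonadWriter [String] m => Int -> m Int
logged n = do
  tell ["got " ++ show n]       -- output only
  writer (n * 2, ["doubled"])   -- result and output together

main :: IO ()
main = print (runWriter (logged 21))  -- prints (42,["got 21","doubled"])
```

Here runWriter picks the lazy Writer as the concrete instance; the body of logged does not need to know that.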
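More generally, a value of type Monad m => m a is an action: m is an instance of Monad, and a is the type of the result of running the action.
A function, by contrast, is a fixed relation between input arguments and output results. Haskell is a purely functional language, which means that at run time a function behaves exactly as it does at compile time: it is only a mapping from inputs to outputs and does nothing else.
An action takes no input arguments and produces a result. When it is run at run time (for example inside do { m a ; ... }), it may have side effects on its environment.
To get a feel for the difference, consider: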
putStrLn :: String -> IO ()
is a function, whereas
putStrLn "hello world" :: IO ()
is an action. Running this action prints "hello world" to the screen, and its result type is ().
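The distinction becomes tangible once actions are treated as ordinary values. In this small sketch (greet and hello are names invented here), building a list of actions prints nothing; output only appears when the actions are run.

```haskell
-- A function: it needs a String before it yields an action.
greet :: String -> IO ()
greet name = putStrLn ("hello " ++ name)

-- An action: a plain value; nothing is printed when it is defined.
hello :: IO ()
hello = greet "world"

main :: IO ()
main = do
  let actions = [hello, hello]  -- building the list prints nothing
  print (length actions)        -- prints 2
  sequence_ actions             -- now "hello world" is printed twice
```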
Coming back to our topic, Lazy.WriterT w m is the WriterT w m we examined above.
instance (Monoid w, Monad m) => MonadWriter w (Lazy.WriterT w m) where
    writer = Lazy.writer
    tell   = Lazy.tell
    listen = Lazy.listen
    pass   = Lazy.pass
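Next we look at how WriterT implements the four MonadWriter operations (they are called operations because such methods often take a monadic value as an argument or contain a do block, and so run actions).
Lazy.writer is the writer function already covered in the discussion of return, so we skip it.
Lazy.tell is defined in WriterT as: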
-- | @'tell' w@ is an action that produces the output @w@.
tell :: (Monad m) => w -> WriterT w m ()
tell w = writer ((), w)
We know that writer turns a pair (a, w) into a value of type WriterT w m a, so for tell we have:
tell w = writer ((), w)
       = (WriterT . return) ((), w)
       = WriterT (return ((), w))  -- a value of type WriterT w m ()
Lazy.listen is defined in WriterT as follows; the steps are self-explanatory:
-- | @'listen' m@ is an action that executes the action @m@ and adds its
-- output to the value of the computation.
--
-- * @'runWriterT' ('listen' m) = 'liftM' (\\ (a, w) -> ((a, w), w)) ('runWriterT' m)@
listen :: (Monad m) => WriterT w m a -> WriterT w m (a, w)
listen m = WriterT $ do
    ~(a, w) <- runWriterT m
    return ((a, w), w)
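A short usage sketch of listen, with inner as a made-up name: the output of the wrapped computation shows up both inside the result and in the log.

```haskell
import Control.Monad.Trans.Writer (Writer, listen, runWriter, tell)

-- (`inner` is a hypothetical name for this sketch.)
inner :: Writer [String] Int
inner = tell ["a"] >> return 1

main :: IO ()
main = print (runWriter (listen inner))  -- prints ((1,["a"]),["a"])
```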
Lazy.pass is defined in WriterT as:
-- | @'pass' m@ is an action that executes the action @m@, which returns
-- a value and a function, and returns the value, applying the function
-- to the output.
--
-- * @'runWriterT' ('pass' m) = 'liftM' (\\ ((a, f), w) -> (a, f w)) ('runWriterT' m)@
pass :: (Monad m) => WriterT w m (a, w -> w) -> WriterT w m a
pass m = WriterT $ do
    ~((a, f), w) <- runWriterT m
    return (a, f w)
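A short usage sketch of pass, with reversed as a made-up name: the computation returns its result together with a function, and that function is applied to the output it produced.

```haskell
import Control.Monad.Trans.Writer (Writer, pass, runWriter, tell)

-- The inner block returns (result, function); `pass` applies the
-- function to the inner block's output.
-- (`reversed` is a hypothetical name for this sketch.)
reversed :: Writer [String] Int
reversed = pass $ do
  tell ["a", "b"]
  return (1, reverse)

main :: IO ()
main = print (runWriter reversed)  -- prints (1,["b","a"])
```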
Common operations
-- | Extract the output from a writer computation.
--
-- * @'execWriterT' m = 'liftM' 'snd' ('runWriterT' m)@
execWriterT :: (Monad m) => WriterT w m a -> m w
execWriterT m = do
    ~(_, w) <- runWriterT m
    return w
-- | Map both the return value and output of a computation using
-- the given function.
--
-- * @'runWriterT' ('mapWriterT' f m) = f ('runWriterT' m)@
mapWriterT :: (m (a, w) -> n (b, w')) -> WriterT w m a -> WriterT w' n b
mapWriterT f m = WriterT $ f (runWriterT m)
-- | @'listens' f m@ is an action that executes the action @m@ and adds
-- the result of applying @f@ to the output to the value of the computation.
--
-- * @'listens' f m = 'liftM' (id *** f) ('listen' m)@
--
-- * @'runWriterT' ('listens' f m) = 'liftM' (\\ (a, w) -> ((a, f w), w)) ('runWriterT' m)@
listens :: (Monad m) => (w -> b) -> WriterT w m a -> WriterT w m (a, b)
listens f m = WriterT $ do
    ~(a, w) <- runWriterT m
    return ((a, f w), w)
All of these implementations are straightforward.
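The three operations can be tried out together. In the sketch below, step is a hypothetical computation that is polymorphic in the inner monad, so the same definition can run over IO and over Identity; mapWriterT is used here only to move a computation from Identity to IO.

```haskell
import Control.Monad.Trans.Writer
  (WriterT, execWriterT, listens, mapWriterT, runWriterT, tell)
import Data.Functor.Identity (runIdentity)

-- (`step` is a made-up name; it works for any inner monad m.)
step :: Monad m => WriterT [String] m Int
step = tell ["start"] >> tell ["end"] >> return 7

main :: IO ()
main = do
  -- execWriterT keeps only the output; the inner monad here is IO.
  w <- execWriterT step
  print w   -- prints ["start","end"]
  -- listens pairs the result with a function of the output.
  r <- runWriterT (listens length step)
  print r   -- prints ((7,2),["start","end"])
  -- mapWriterT changes the inner monad from Identity to IO.
  r' <- runWriterT (mapWriterT (return . runIdentity) step)
  print r'  -- prints (7,["start","end"])
```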
It is also worth looking at the other common operations on Writer:
-- | Unwrap a writer computation as a (result, output) pair.
-- (The inverse of 'writer'.)
runWriter :: Writer w a -> (a, w)
runWriter = runIdentity . runWriterT
-- | Extract the output from a writer computation.
--
-- * @'execWriter' m = 'snd' ('runWriter' m)@
execWriter :: Writer w a -> w
execWriter m = snd (runWriter m)
-- | Map both the return value and output of a computation using
-- the given function.
--
-- * @'runWriter' ('mapWriter' f m) = f ('runWriter' m)@
mapWriter :: ((a, w) -> (b, w')) -> Writer w a -> Writer w' b
mapWriter f = mapWriterT (Identity . f . runIdentity)
runWriter and writer are inverses of each other. For a Writer value of the form WriterT (Identity (a, w)), runWriter unfolds step by step as follows:
runWriter (WriterT (Identity (a, w)))
  = (runIdentity . runWriterT) (WriterT (Identity (a, w)))
  = runIdentity (Identity (a, w))
  = (a, w)
The other two functions can be unfolded in the same way.
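The inverse relationship can be checked directly. A minimal sketch, where roundTrip is a name invented for the example:

```haskell
import Control.Monad.Trans.Writer (runWriter, writer)

-- Wrapping a pair with `writer` and unwrapping it with `runWriter`
-- gives back the original pair.
-- (`roundTrip` is a hypothetical name for this sketch.)
roundTrip :: (Int, String) -> (Int, String)
roundTrip = runWriter . writer

main :: IO ()
main = print (roundTrip (1, "log"))  -- prints (1,"log")
```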
Back to the example
Having read the source, we can finally analyse the opening example in full. The key definition is:
calc :: Writer [LogEntry] Integer
calc = do
  tell [LogEntry "Start"]
  let x = sum [1 .. 10000000]
  tell [LogEntry (show x)]
  tell [LogEntry "done"]
  return x
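The first key point is the do block. Inside it we call tell three times. On the surface we never bind the result of tell with a <- m and simply discard it, so how does the accumulation happen? Let us remove the do syntactic sugar: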
calc :: Writer [LogEntry] Integer
-- The translation of let is slightly imprecise, but it makes no difference in this example
calc = let x = sum [1 .. 10000000] in
  tell [LogEntry "Start"] >> tell [LogEntry (show x)] >> tell [LogEntry "done"] >> return x
The second key point is the >> operator. From the source quoted above we know that neither WriterT nor Writer defines its own >>, so this >> is the default one from Monad:
(>>) :: m a -> m b -> m b
a >> f = a >>= \_ -> f
Consider only tell [LogEntry "Start"] >> tell [LogEntry (show x)] for now. show x evaluates to "50000005000000"; then we apply the definition of >>= for WriterT:
  tell [LogEntry "Start"] >> tell [LogEntry "50000005000000"]
= tell [LogEntry "Start"] >>= \_ -> tell [LogEntry "50000005000000"]
-- replace tell with writer
= writer ((), [LogEntry "Start"]) >>= \_ -> writer ((), [LogEntry "50000005000000"])
-- apply the definition of >>=
= WriterT $ do
    ~(a, w)  <- runWriterT (writer ((), [LogEntry "Start"]))
    ~(b, w') <- runWriterT ((\_ -> writer ((), [LogEntry "50000005000000"])) a)
    return (b, w `mappend` w')
-- the right-hand sides of "<-" can be evaluated; for Writer the inner monad m is Identity
= WriterT $ do
    ~(a, w)  <- Identity ((), [LogEntry "Start"])
    -- (\_ -> m) a = m: apply the function to its argument
    ~(b, w') <- Identity ((), [LogEntry "50000005000000"])
    return (b, w `mappend` w')
-- a = (), b = (), w = [LogEntry "Start"], w' = [LogEntry "50000005000000"]
= WriterT $ return ((), [LogEntry "Start"] `mappend` [LogEntry "50000005000000"])
= WriterT (Identity ((), [LogEntry "Start", LogEntry "50000005000000"]))
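Adding tell [LogEntry "done"] gives:
WriterT (Identity ((), [LogEntry "Start", LogEntry "50000005000000", LogEntry "done"]))
The trailing >> return x then appends mempty to the output, which leaves it unchanged, and replaces the result () with x.
As for execWriter and runWriter, the source above already makes them clear, so they need no further explanation.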
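The desugaring can also be checked mechanically by writing both versions and comparing their results. This is a sketch under two assumptions: the names calcDo and calcBind are made up, and the range is shortened to [1 .. 100] so the check runs instantly.

```haskell
import Control.Monad.Trans.Writer (Writer, runWriter, tell)

data LogEntry = LogEntry { msg :: String } deriving (Eq, Show)

-- The do-notation version (with a smaller range than the article's example).
calcDo :: Writer [LogEntry] Integer
calcDo = do
  tell [LogEntry "Start"]
  let x = sum [1 .. 100]
  tell [LogEntry (show x)]
  tell [LogEntry "done"]
  return x

-- The same computation written with explicit >>=.
calcBind :: Writer [LogEntry] Integer
calcBind =
  let x = sum [1 .. 100]
  in tell [LogEntry "Start"] >>= \_ ->
     tell [LogEntry (show x)] >>= \_ ->
     tell [LogEntry "done"] >>= \_ ->
     return x

main :: IO ()
main = print (runWriter calcDo == runWriter calcBind)  -- prints True
```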
References
- https://hackage.haskell.org/package/mtl-2.2.2/docs/src/Control.Monad.Writer.Class.html#MonadWriter
- https://hackage.haskell.org/package/transformers-0.6.0.2/docs/src/Control.Monad.Trans.Writer.Lazy.html
- https://hackage.haskell.org/package/transformers-0.6.0.2/docs/src/Control.Monad.Trans.Writer.Strict.html
- https://kseo.github.io/posts/2017-01-21-writer-monad.html
- http://dev.stephendiehl.com/fun/basics.html
- https://mmhaskell.com/monads/reader-writer
- https://blog.ssanj.net/posts/2018-01-12-stacking-the-readert-writert-monad-transformer-stack-in-haskell.html