合成できるモナド、モナドが合成できる時

(今回のコードは github:everpeace/composing-monads にあります。)
一般に、モナドって合成できないって言われますよね。

でも、モナドって合成できる場合も有るんです。

今回はまず、合成が難しい(できない)理由を説明して、

じゃぁ「モナドが合成できる時」はいつなのか?

というのと「合成できるモナドたち」をちょっとだけ紹介してみます。

それにはまず、ここから始めましょう。

モナドをflatMapじゃなくてflattenで定義してみる。

flatMapで定義されるモナドのおさらい

Mというモナドは次の用に定義されます。

trait Functor[F[_]]{
    def map[A,B](fa:F[A])(f:(A) => B):F[B]
}
trait Monad[M[_]] extends Functor[M]{
    def unit[A](a:A):M[A]
    def flatMap[A,B](ma: M[A])(f:A => M[B]):M[B]
}

有名なリストモナド(たくさんの値を返す計算の抽象です)はこんな風に定義されます。

implicit val ListMonad = new Monad[List]{
    def map[A,B](as:List[A])(f:(A)=>B) = as.map(f)
    def unit[A](a:A):List[A] = List(a)
    def flatMap[A,B](as:List[A])(f:A => List[B]) = (List[B]()/:as)(_++f(_))
}

リストモナドのflatMapが何をやってるかおさらいしましょう。

def repeat2(a:Int):List[Int] = List(a,a)
ListMonad.flatMap(List(1,2,3))(repeat2) //=== List(1,1,2,2,3,3)

分解して説明すると、List(1,2,3)の各要素にrepeat2を施すとこんな感じ。

1 => List(1,1)
2 => List(2,2)
3 => List(3,3)

これを、List[List[Int]]じゃなくて、それをflatにしたList[Int]にして返すのがflatMapです。名前の通りなのですが、ただmapするだけだと、List[List[_]]になるので、それをflatする。

モナドのパワーがここにあるのは様々なブログに書かれている通りです。

flatMapじゃなくてflattenでMonadを定義してみる。

mapは与えられるのだから、モナドのパワーの根源はflatする部分にあると感じませんか?圏論の小難しい事は抜きにしますが、圏論におけるモナドは元々, unit, flattenの2つで定義されているのです!

コードで書くとこんな具合です。

trait Monad[M[_]] extends Functor[M]{
    def unit[A](a:A):M[A]
    def flatten[A](mm:M[M[A]]):M[A]
}

上と同じ様にListモナドをflattenで定義してみましょう。

implicit val ListMonad = new Monad[List]{
    def map[A, B](as: List[A])(f: (A) => B) = as.map(f)
    def unit[A](a: A) = List(a)
    def flatten[A](ass: List[List[A]]) =  (List[A]()/:ass)(_++_)
  }

こうですね。

さて、ちょっと抽象の世界に戻って考えてみましょう。flattenは2重になったモナドを1重にしてくれますよね。flatMapに渡されるfをmapに渡すとM[M[A]]になるので、実は、flatMapはこのflattenを使って抽象レベルで定義できるんです。

trait Monad[M] extends Functor[M]{
    def unit[A](a:A):M[A]
    def flatten[A](mma:M[M[A]]):M[A]
    def flatMap[A,B](f:A=>M[B])(ma:M[A]) = flatten(fmap(f,ma))
}

さっきのList Monadの例に戻ってみましょう。

implicit val ListMonad = new Monad[List]{
    def map[A, B](as: List[A])(f: (A) => B) = as.map(f)
    def unit[A](a:A):List[A] = List(a)
    def flatten[A](ass:List[List[A]]):List[A] = (List[A]() /: ass)(_++_)
    def flatMap[A,B](f:A=>List[B])(ma:List[A]):List[A] = flatten(fmap(f,ma))
}

このflatmapを分解して実行してみると、

ListMonad.flatMap(List(1,2,3)(repeat2)) 
= flatten(map(List(1,2,3))(repeat2))
= flatten(List(List(1,1),List(2,2),List(3,3))
= List(1,1,2,2,3,3)

flattenを使って、flatMapが元のと同じように動作しているのが分かりますね。(fの実行タイミングが違うのは置いておいて)

モナドの合成の難しさ

さて、flattenを通してモナドを理解してみると、モナド合成の難しさが見えやすくなります。

二つのモナドM,Nを考えましょう。モナドを合成するということは、M,Nという二つのモナドからMN[A]というモナドを作り出す事です。つまり、Monad[M], Monad[N]のインスタンスが与えられたとき、Monad[({type MN[α] = M[N[α]]})#MN]を与えよという事になります。

new Monad[({type MN[α] = M[N[α]]})#MN]{
      def map[A,B](ma:M[N[A]])(f: (A) => B) = /* これを実装する */
      def unit[A](a: A) =  /* これを実装する */
      def flatten[A](mnmn:M[N[M[N[A]]]])= /* これを実装する */
 }

簡単のために、M,Nで定義されている関数をこんな風に最後にM,Nをつけて区別することにします。

mapM, unitM, flattenM
mapN, unitN, flattenN

ここから、作り出したいのは

mapMN[A,B](mna: M[N[A])(f: A=>B): M[N[B]]
unitMN[A](a:A): M[N[A]]
flattenMN[A](mna: M[N[M[N[A]]]]):M[N[A]]

です。

unitMNは簡単そうですよね。

unitMN[A](a:A):M[N[A]] = unitM(unitN(_))

でオッケーそうです。

mapMNもなんとか行けそうですね。

// Mの中をmapN(f)するイメージ
mapMN[A,B](mna:M[N[A])(f: A=>B) = mapM(ma)(mapN(_)(f))

では、最後にflattenMNを考えてみましょう。

flatten[A](mnmn: M[N[M[N[_]]]]):M[N[A]]

を与えるにはどうすればいいでしょう?

今与えられている、mapM, unitM, flattenM, mapN, unitN, flattenN, mapMN, unitMNを組み合わせるだけでは、型からしてそのような型の関数を与える事はできないのです。

ここが、一般に二つのモナドが合成が無理と言われる所なのです。(flatMapを合成するよりflattenから見た方が簡単ですよね??)

モナドが合成できる時 (スワップできるモナドたち)

目指すは、flattenMNです。この関数の型に注目してみましょう。

M[N[M[N[A]]]] => M[N[A]]

です。flattenM,flattenNだけではにっちもさっちもいかなかったんですが、この引数の型の真ん中のNMをMNにひっくり返すようなことができたらどでしょう?

M[N[M[N[A]]]] 
==(真ん中のNMを反転)==> M[M[N[N[A]]]] 
==(flattenN)==> M[M[N[A]]] 
==(flattenM)==> M[N[A]]

とできるではないですか!

そんなにうまく行くはずがない。と思うでしょう。それができるんです。

swapというNMをMNにひっくり返すような関数が与えられているとしましょう。

swap[A](nma:N[M[A]):M[N[A]]

ですね。

このswapをmapMに渡すと何が起こるでしょうか?

//mapM(mnmn)(swap(_))の型は?
swap: N[M[_]] => M[N[_]]
mapM:  (M[X])(X=>Y) => M[Y]
//X=N[M[N[A]]], Y = M[N[N[A]]なのでM[Y] = M[M[N[N[A]]]]
mapM(mnmn)(swap(_)): M[M[N[N[A]]]]

です!なんと、MNMNの真ん中だけが見事にひっくり返ってMMNNになっています!

こうなれば、最初の戦略にあったように、flattenN, flattenMを順番にかければ、flattenMNの完成です!!!(実際は、ただ順番にかけるのではなく、mapに渡してやる必要はありますが)

flattenMN(mnmn:M[N[M[N[A]]]]):M[N[A]] 
= mapM(flattenM(mapM(mnmn)(swap(_))))(flattenN(_))

です。

つまり、もし、モナドM,Nについて、

swap: N[M[_]] => M[N[_]]

というモナドをひっくり返すswapper(圏論ではこれはdistributive lawと呼ばれます)が与えられているときには合成できる。(モナドMonad Lawを満たさなくてはならないように、本当はswapが満たすべきLawがいくつも有ります。)

ということなんです。

合成できるモナド

swapが満たすべき条件については今回は触れませんが、合成できるモナドについて書いてみます。

任意のMonad MとOption

任意のモナド MとOptionについて M[Option[_] ]というのは合成できます。こんな風にswapを与える事ができます。

def swap[A](oma: Option[M[A]]) = oma match{
      case Some(ma) => monadM.map(ma)(Option(_))
      case None => monadM.unit(None)
}
任意のMonad MとList

任意のモナド MとListについて M[List[_] ]というのは合成できます。こんな風にswapを与える事ができます。

def swap[A](nm: List[M[A]]):M[List[A]] = nm match {
    case List() => monadM.unit(List())
    case x::xs => for { y <- x;
                                  ys <- swap(xs)} yield y::ys
}
任意のMonad MとValidation

こんなValidationモナド

trait Validation[A]
case class Ok[A](a: A) extends Validation[A]
case class Error[A](msg: String) extends Validation[A]

object Validation {
  def ok[A](a:A):Validation[A] = Ok(a)
  def error[A](msg:String):Validation[A]=Error[A](msg)

  implicit val ValidationMonad = new Monad[Validation] {
    def unit[A](a: A) = Ok(a)
    def map[A, B](ma: Validation[A])(f: (A) => B) = ma match {
      case Ok(a) => Ok(f(a))
      case Error(msg) => Error(msg)
    }
    def flatten[A](mm: Validation[Validation[A]]) = mm match {
      case Ok(a) => a
      case Error(msg) => Error(msg)
    }
  }
}

と任意のモナド Mについて M[Validation[_] ]というもは合成できます。こんな風にswapを与える事ができます。

def swap[A](nm: Validation[M[A]]) = nm match{
    case Ok(a) => monadM.map(a)(Ok(_))
    case Error(msg) => monadM.unit(Error(msg))
}

こんな風にできます。

こんな風(github)に定義してやると、scala のfor comprehensionを活用してこんな風にできます。

// 普通にList[Option[_]]をfor文で束縛するとOptionが束縛される。
// i はSome(1)
for { i<- List(Option(1))         } yield {println(i);i}

// でも、composeするとfor文で束縛されるのiは1。
for { i<- compose(List(Option(1)))} yield {println(i);i}

// List[Option[_]] monad の計算が一度にできます。
// this outputs: List(Some(4), Some(5), Some(8), Some(9))
println(for { i<- compose(List(Option(1), None, Option(5)))
                  j<- compose(List(Option(3), Option(4)))     } yield i+j)

// Option[List[_]]の計算も出来ます。
// this outputs: Some(List(4, 5, 8, 9))
  println(for { i<- compose(Option(List(1, 5)))
                     j<- compose(Option(List(3, 4)))     } yield i+j)


// List[Validation[_]]の計算もいっぺんに出来ます。
// this outputs: List(Ok(3), Error(error2), Error(error3), 
//                            Error(error1), Ok(4), Error(error2), Error(error3))
  println(for { i <- compose(List(ok(1),error[Int]("error1"),Ok(2)));
                     j <- compose(List(ok(2),error[Int]("error2"),error[Int]("error3")))} yield i+j )

合わせて読みたい。

今回のコード

今回のコードはgithub:everpeace/composing-monadsに有ります。