Scalaの継続モナド(Continuationモナド)を理解する

Scala実践プログラミングに出てきた継続モナド(Continuationモナド)を理解するに手こずったので、備忘録として簡単に説明を残します。該当ページはp.304,305です。

Scala実践プログラミング―オープンソース徹底活用

以下は継続モナドの定義です。

class Cont[A, B](m: (A => B) => B) {
// for内包表記用のメソッド。mapとflatMap。
def map[C](f: A => C): Cont[C, B] =
new Cont(k => m(x => k(f(x))))

def flatMap[C](f: A => Cont[C, B]): Cont[C, B] =
new Cont(k => m(x => f(x).run(k)))

// モナドを実行するメソッド
def run(k: A => B): B = m(k)
}

この継続モナドを使用したサンプル(ほぼ本のまま)は以下の通りです。本ではwork.run(u => u)の部分がありませんが、この行がないとモナドの中身が実行されません。

object Main extends App {
type TClosable = { def close(): Unit }

def using[C <% TClosable, T](h: C)(work: C => T): T = {
try {
work(h)
} finally {
h.close()
}
}

def withFile(file: String): Cont[BufferedSource, Unit] = {
new Cont(k => {
using(Source.fromFile(file)) {r: BufferedSource => k(r)}
})
}

val work =
for {
foo <- withFile("foo.txt")
bar <- withFile("bar.txt")
} yield {
for(line <- foo.getLines.zip(bar.getLines)) println(line)
}

work.run(u => u)
}

Contクラス(継続モナド用クラス)では関数リテラルと型パラメータが多用されており、さらにContが再帰的に構築されるようになっているためかなり複雑で理解しづらいです。理解しなくても問題なく利用できますが、やっぱりちゃんとどういう仕組で動いているか理解したいですよね!

ということで以下では上記のサンプルがどういう仕組みになっているのかコードを追って確認したいと思います。

まずfor内包表記を変換規則に従って変換すると以下のコードになります。

val work = withFile("foo.txt").flatMap(foo =>
withFile("bar.txt").map(bar =>
for(line <- foo.getLines.zip(bar.getLines)) println(line)
))

ではコードを順に追っていきましょう。

まずwithFile(“foo.txt”)の部分でContのインスタンスが生成されます。
型推論で省略された型を明記すると、withFileはCont[BufferedSource, Unit]を返すことからnew Contした型はCont[BufferedSource, Unit]です。

new Cont(k => {
using(Source.fromFile(file)) {r: BufferedSource => k(r)}
})

class Cont[A, B](m: (A => B) => B)なのでAがBufferedSource、BがUnitとなります。
Contのコンストラクタパラメータであるmはややこしいですが、A型を引数にとりB型を返す関数を引数に取りB型を返す関数です(日本語の方がわけわからないですね^^;)
従って、以下のkはBufferedSource => Unitとなる関数です。k(r)の戻り値はUnitとなるため、mの定義に合致しています。
mは(BufferedSource => Unit) => Unitとなる関数で、実体は以下のコードになります。

k => {
using(Source.fromFile(file)) {r: BufferedSource => k(r)}
}

次にwithFile(“foo.txt”)で作成したContインスタンスに対してflatMapが呼ばれています。
flatMapの引数fに該当するのは以下のコードです。

foo =>
withFile("bar.txt").map(bar =>
for(line <- foo.getLines.zip(bar.getLines)) println(line)
)

flatMapを呼び出したのはCont[BufferedSource, Unit]なので、fはBufferedSource => Cont[C, Unit]と推測できます。

Cは何型でしょうか?コードを追っていくとfor(line <- foo.getLines.zip(bar.getLines)) println(line)の行の型がCとなりそうです。従ってCはUnitでfはBufferedSource => Cont[Unit, Unit]となります。

flatMapの中でも新たなContインスタンスが生成されています。

def flatMap[C](f: A => Cont[C, B]): Cont[C, B] =
new Cont(k => m(x => f(x).run(k)))

Contのコンストラクタパラメータとなるmの型から推論すると、kはC => B、xはAとなります。
新たに作られたContのmは関数で中身はk => m(x => f(x).run(k))です。右側のmはwithFile(“foo.txt”)で作られたContインスタンスのmです(中身はk => { using(Source.fromFile(file)) {r: BufferedSource => k(r)}})。

この時点では関数オブジェクト(mやf)が複数作られているだけで、中身は一切実行されていないことに注意して下さい。

workはflatMapで新たに作られた上記のCont[Unit, Unit]が代入されます。
そしていよいよモナドの実行です。

work.run(u => u)

runメソッドはk: A => B、この場合はUnit => Unitとなる関数を引数に取るので、何もしない関数を上記では渡しています。

runではm(k)が実行されます。
最初はworkに代入されたCont、つまりflatMapで作成されたContインスタンスフィールドのm関数(以降便宜上flatMap#mと呼びます)が、上記の何もしない関数kを引数に実行されます。
flatMap#mの中身は以下です。

k => m(x => f(x).run(k))

ここでようやく右側の関数定義部分が評価されます。
右側のmはwithFile(“foo.txt”)で作られたContインスタンスのm(以降便宜上foo#mと呼びます)であることに注意して下さい。

x => f(x).run(k)はfoo#mの引数となるのでBufferedSource => Unitという関数オブジェクトとなります。

関数本体に当たるf(x).run(k)の部分はこの時点では実行されません。

この関数を引数kとしてfoo#mが呼び出されます。
foo#mは以下の通りです。

k => {
using(Source.fromFile(file)) {r: BufferedSource => k(r)}
}

ここでようやく”foo.txt”部分のusingが呼ばれ、中でBufferedSourceを引数としてk(r)が呼ばれます。

kは上記のflatMapで定義された以下の関数オブジェクトです。

x => f(x).run(k)

ややこしいですが右側のkは最初に渡した何もしない関数です。xはr: BufferedSourceに相当し、ようやくここでf(x)が評価されます。
fはflatMapに渡された引数(便宜上flatMap#f)で以下のコードであることを思い出して下さい。

foo =>
withFile("bar.txt").map(bar =>
for(line <- foo.getLines.zip(bar.getLines)) println(line)
)

f(x)を評価するとさっきと同じようにwithFile(“bar.txt”)が実行されてさらに新たなCont[BufferedSource, Unit]が作られます。
その後今度はmapメソッドが実行されここでもflatMapと同様に新たなCont[Unit, Unit]が作られます。
mapの引数であるf(便宜上bar#f)の中身は以下のコードです。

bar => for(line <- foo.getLines.zip(bar.getLines)) println(line)

この後f(x).run(k)の評価に戻ります(このfはflatMap#f)。

f(x)の戻り値はCont[Unit, Unit]なので、このインスタンスに対してrun(k)が呼ばれます。

kは何もしない関数です。

run(k)を実行するとmapで作成したCont[Unit, Unit]のmが呼ばれます。従って、mは以下のコードです。

k => m(x => k(f(x)))

右側のmはwithFile(“bar.txt”)で作成したインスタンスのm(便宜上bar#m)です。

fはbar#fです。

さっきと同様x => k(f(x))という関数オブジェクトを引数にbar#mが呼ばれます。

bar#mは以下の通りです。

k => {
using(Source.fromFile(file)) {r: BufferedSource => k(r)}
}

kはx => k(f(x))です(右側のkは何もしない関数)。
fはbar#fです。

従って”bar.txt”のBufferedSourceを引数としてやっと以下のbar#fが呼ばれます。

bar => for(line <- foo.getLines.zip(bar.getLines)) println(line)

kは何もしない関数なので、k(f(x))を実行しても何も起こりません。

これでwork.run(u => u)の呼び出しも終了です。

ふぅ、疲れた。。。