JubatusにThompson samplingを実装する

Jubatus0.7.0についにBanditアルゴリズムが実装されたのですが、漸近最適なアルゴリズムがまだ実装されていないので、Thompson sampling (TS) を実装してみました。

TSの詳細はThompson sampling - Wikipedia, the free encyclopediaなどに詳しいです。TSはThompsonさんが1930年に提案された最も古いアルゴリズムの1つなのですが、バンディット業界ではUCBなどと比べるとほとんど知られていませんでした。Googleの中の人がABテストに利用したことや、NIPS2011でTSの性能を他のアルゴリズムと比較した論文が出版され、圧倒的に性能が良いことが示されたことでTSは一躍注目をされるようになりました。理論的にも、TSは漸近最適なアルゴリズム*1の1つとして知られています。このアルゴリズムベイズ推定に基づくためその実装が報酬の確率分布に依存するのですが、今回は報酬が{0,1}の場合を考えてみましょう。この場合の共役事前分布はBeta分布なので、Beta分布を実装すればTSアルゴリズムを作ることができます。

というわけで、TSを実装してGithubに置いてみました*2

https://github.com/jkomiyama/jubatus_core/commit/ce2abe22909473a557af957064e2e5c220d361b5

性能を比較してみましょう。
サーバスクリプト: ts.json

{
  "method": "ts",
    "parameter": {
  }
}

で、Banditサーバを起動します。

jubabandit -f ts.json

クライアントでbanditのシミュレーションを回してみます。
クライアントスクリプト: test.py

#!/usr/bin/env python
# coding: utf-8

host = '127.0.0.1'
port = 9199
name = 'test'

import sys
import json
import random
from math import *

import jubatus

def kl(p,q):return p*log(p/q)+(1-p)*log((1-p)/(1-q))

def run(client):
  arms = {
    'best':0.1,
    'soso':0.05,
    'bad':0.02
  }
  player = 'bandit'
  bestAvg = arms['best']
  counts = {}
  for arm in arms.keys():counts[arm]=0

  client.reset(player)
  regret = 0
  for arm in arms.keys():
    client.register_arm(arm)
  for t in xrange(10000):
    arm = client.select_arm(player) 
    if random.random() <= arms[arm]:
      reward = 1.0
    else:
      reward = 0.0
    counts[arm]+=1
    regret += bestAvg - arms[arm]
    client.register_reward(player, arm, reward)
  print "regret =",regret 
  #regret asymptotic lower bound
  lower = sum([log(10000)*(bestAvg-arms[arm])/kl(arms[arm],bestAvg) for arm in arms.keys() if arms[arm] != bestAvg])
  print "lower =",lower

if __name__ == '__main__':
    client = jubatus.Bandit(host, port, name)
    run(client)

サーバがUCB1の場合*3と、今回実装したTSの場合の結果比較:
f:id:jkomi:20150301150125p:plain

約9倍の性能向上(損失減少)ができました*4

*1:どんなアームの集合に関しても上手く動くアルゴリズムの中で、理論的な性能が最も良いもの

*2:jubatusにpull requestを出したいのですが、著作権周りがよくわからないので保留中・・・

*3:Jubatusにすでに実装されている別のバンディットアルゴリズム、これも有名

*4:バンディットは確率的な問題のため、ちゃんと性能比較したい場合、もっとたくさんの回数のシミュレーションを行って期待値を取る必要がありますが、今回は簡単のため1回しかやっていません