pythonでクロージャを渡して順列生成

組み合わせを解く問題をやるときに毎回コーディングするのは面倒だ。一般的なコード書こう。

N個のシーケンスseqの中からk個を選び出す順列を生成する。

# permutation.py
def default_term(x):
    return [x]
def default_concat(x, y):
    return [x] + y

def kperm(seq, k, terminal_procedure=default_term, concat_procedure=default_concat):
    # seqからk個取り出す順列を生成する
    if k==1:
        for item in seq:
            yield terminal_procedure(item)
    else:
        for i in xrange(len(seq)):
            for p in kperm([seq[idx] for idx in xrange(len(seq)) if i!=idx],
                            k-1,
                            terminal_procedure,
                            concat_procedure):
                yield concat_procedure(seq[i], p)

def perm(seq, tp=default_term, cp=default_concat):
    # 全ての組み合わせを列挙
    return kperm(seq, len(seq), tp, cp)

上記の関数に組み合わせの操作が終了したときに実行される終端手続きterminal_procedureと、操作の途中での結果の結合に使う結合手続きconcat_procedureには関数オブジェクトを渡す。

このやりかたのいいところは、用途に応じて手続きを変えられるところ。たとえば、1から9のそれぞれの数字を1回だけ使った数字を全て列挙するなら

def honor(p):
    # ホーナー法
    return reduce(lambda x,y: x*10+y, p)

for p in perm(range(1,10)):
    honor(p)

こういうコードでもちゃんと動作する。けれど組み合わせが大きくなると激しく計算コストがかかる。これはデフォルトの終端手続きと結合手続きがリストの足し算を行っているためで、新しいリストオブジェクトが毎回新しく作られるため。より効率よく上の問題を解くなら、終端手続きと結合手続きを自前で定義して渡せばよい。リストを毎回生成するのではなく、組み合わせを生成しながらホーナー法で数字を構築する。

for n in perm(range(1,10), lambda x:x, lambda x,y: y*10+x):
    n

このコードは上に比べて高速に動作する。

$ python
>>> import permutation
>>> import time
>>> def honor(p):
...   return reduce(lambda x,y: x*10+y, p)
...
>>> def measure(proc):
...   s = time.time()
...   proc()
...   e = time.time()
...   return e-s
...
>>> def honor(p): return reduce(lambda x,y: x*10+y, p)
>>> measure(lambda: [honor(p) for p in permutation.perm(range(1,10))])
26.201292991638184
>>> measure(lambda: [n for n in permutation.perm(range(1,10), lambda x: x, lambda x,y: y*10, x))
16.391520023345947


これで単純な組み合わせのコードより柔軟さが増した。

追記:id:unauさんの指摘で最後のほうに間違いがあったので、修正してからdellのinspiron mini 9で実行しなおしました。id:unauさん、ありがとうございます!

Python 2.6.4 (r264:75708, Oct 26 2009, 08:23:19) [MSC v.1500 32 bit (Intel)] on
win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import permutation
>>> import time
>>> def honor(p): return reduce(lambda x,y:x*10+y, p)
...
>>> def measure(proc):
...  s = time.time()
...  proc()
...  e = time.time()
...  return e-s
...
>>> measure(lambda: [honor(p) for p in permutation.perm(range(1,10))])
20.092999935150146
>>> measure(lambda: [n for n in permutation.perm(range(1, 10), lambda x: x, lambda x,y: y*10+x)])
11.453000068664551