平方数の判定を高速にやりたい

ある数が平方数かどうか(整数を2乗した整数かどうか)判定して、平方数ならその平方根を得たい。
しかも、大きな数(例えば、3^256)でも正しく判定できることが条件。
という観点で試行錯誤したのでメモ。

考えた後で思ったけどPARI/GPとかになら組み込みで超高速な判定が実装されてそう・・・

方法1. 素朴な実装

nの平方根の整数部分を浮動小数点数で求めて、2乗したらnになるかどうか調べる。

  • ○簡単
  • ○速い
  • ×大きな数には使えない(doubleの精度の問題)。例えば17003146110244618067556674657992186608423729 ( 4123487129874982169223^2 ) の判定には失敗する
#! /usr/bin/python2.7
def try_square_root_naive(n):
    m = int(n**.5)
    return m if abs(m*m - n) < 1e-6 else None

方法2. 2進化10進数を使った任意精度計算で方法1と同じ事をやる

  • 浮動小数点の精度を上げることで方法1.より大きな値まで正確に判定できる
  • ×どれだけの精度をとるべきかわからない。数値の桁数*2倍の精度で計算してもある程度数値が大きくなると誤判定が起きてしまう。。。
  • ×高い精度で計算すると非常に遅い。
import decimal
a = decimal.Decimal('.5')
def try_square_root_decimal(n):
    decimal.getcontext().prec = len(str(n)) * 2 # 適当に精度を設定
    dn = decimal.Decimal(n)
    pdn = pow(dn, a)
    if pdn - int(pdn) < 1e-6:
        return pdn

2分探索

[0,n]の範囲で、m*m = nになるmを2分探索で見つける。
python標準ライブラリのbisectで2分探索ができるのでそれを使う。

  • ○数値が大きくなっても対数時間で計算できるはず
  • ○大きな数も扱え・・・無い!

⇒ bisectはシーケンスから値を探索する用途を想定されているため、intを超える範囲の数値はOverflowErrorを起こす(っぽい)
(pythonの実装上の問題。)

class SquareRoot(object):
    def __init__(self, n):
        self.n = n
    def __getitem__(self, index):
        return index * index
    def __len__(self):
        return self.n

def try_square_root(n2):
    n = bisect.bisect_left(SquareRoot(n2), n2)
    return n if n*n == n2 else None
"""
>>> try_square_root(144) # 12^2
12
>>> try_square_root(65536) # 256^2
256
>>> try_square_root(17003146110244618067556674657992186608423729) # 4123487129874982169223^2
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 2, in try_square_root
OverflowError: long int too large to convert to int
"""

2分探索(改良版)

大きな数でも使える2分探索。

# 2分探索で以下の条件を満たすmを返す。
# 見つからない場合はNoneを返す
#    f(m) = n, m < [0, n]
# 巨大なnにも対応。これはbisectではできない。
def bisearch(f, n):
    l = 0
    r = n
    fl = f(l)
    fr = f(r)
    while l <= r:
        m = (l + r) // 2
        fm = f(m)
        #print l,m,r,fl,fm,fr,n
        if fm == n:
            return m
        elif fm > n:
            r = m - 1
            fr = f(r)
        else:
            l = m + 1
            fl = f(l)

def try_square_root(n2):
    return bisearch(lambda n:n*n, n2)

こんなもんかな〜

SQLアンチパターン

SQLアンチパターン