Pythonで作るn次元n目並べ その2:1次元n目並べ速度計測

前回作った1次元n目並べをNumPyに書き換えて速度比較してみました。

条件

盤面のサイズやn目並べのnが小さいと計算時間がほぼ0になってしまうので、n目並べの盤面の大きさは1,000,000で5,000個連続した場合勝ちという条件で実行して、盤面をリストで作った場合とNumPyで作った場合の勝ち判定の処理で比較しました。計測回数は先攻後攻を30ターンで計60回です。



結果

ListNumPy
turnplayerputtimeturnplayerputtime
01898540.005984 017910.003990
02698330.003989 02753950.001995
1137320.004986 11667280.002993
12583470.003989 12710690.001997
  ・・・  ・・・
281616840.004987 281140860.001995
282528600.003991 282144780.001994
291209810.003989 291844750.001995
292631480.004987 292146590.001995
ListNumPy
max0.005984 0.003990
min0.003988 0.001994
range0.001996 0.001996
average0.004339 0.002361
0.001540 0.001555

平均でみると0.004339/0.002361=1.8倍くらい速くなりました。それとグラフで見てみると一回目は遅いのですかね。

所感

NumPyは使ったことがほとんどありませんでした。使ってみた感じは、速度は速いし、便利なメソッドがたくさんあり、多次元化するときにも有利かもしれません。次回からはNumPyで作っていこうと思います。





リスト版ソースコード

import random
import time

def choice_index_randomly(arr):
    indexes = [ i for i, d in enumerate(arr) if d==0 ]
    return random.choice(indexes)

def is_win(arr, index, connect_length):
    s = index - connect_length + 1
    if s < 0:
        s = 0
    e = index + connect_length + 1
    if e > len(arr):
        e = len(arr)
    str1 = ''.join( [str(i) for i in arr[s:e]] )
    str2 = ''.join( [str(arr[index]) for i in range(connect_length)] )
    if str2 in str1:
        return True
    return False

def output_text(ls):
    text = '\n'.join( ['\t'.join([str(k) for k in j]) for j in ls] )
    with open('output.txt', mode='w') as f:
        f.write(text)

def main():

    field_size = 100000
    connect_length = 5000

    # create field
    arr = [ 0 for i in range(field_size) ]

    ls = [['turn', 'player', 'put', 'time']]

    # put pieces randomly
    for i in range(30):

        # player 1
        index = choice_index_randomly(arr)
        arr[index] = 1

        start_time = time.time()
        frag = is_win(arr, index, connect_length)
        end_time = time.time()
        t = end_time-start_time
        ls.append([i, 1, index, t])

        if frag:
            print('player 1 is won!!!')
            output_text(ls)
            return

        # player 2
        index = choice_index_randomly(arr)
        arr[index] = 2

        start_time = time.time()
        frag = is_win(arr, index, connect_length)
        end_time = time.time()
        t = end_time-start_time
        ls.append([i, 2, index, t])

        if frag:
            print('player 2 is won!!!')
            output_text(ls)
            return
    output_text(ls)

if __name__ == '__main__':
    main()



NumPy版ソースコード

import numpy as np
import random
import time

def choice_index_randomly(arr):
    indexes = np.where(arr==0)[0]
    return np.random.choice(indexes)

def is_win(arr, index, connect_length):
    change = (arr[1:]==arr[:-1])

    left = np.arange(len(change))
    left[change>0] = 0
    np.maximum.accumulate(left, out=left)

    right = np.arange(len(change))
    right[change[::-1]>0]=0
    np.maximum.accumulate(right, out=right)
    right = len(change) - right[::-1] - 1

    result = np.zeros_like(arr)
    result[:-1] += right
    result[1:] -= left
    result[-1] = 0

    return np.max(result[arr == arr[index]]) >= connect_length

def output_text(ls):
    text = '\n'.join( ['\t'.join([str(k) for k in j]) for j in ls] )
    with open('output.txt', mode='w') as f:
        f.write(text)

def main():

    players = range(1,3)
    field_size = 100000
    connect_length = 5000

    # create field
    arr = np.zeros(field_size, int)

    ls = [['turn', 'player', 'put', 'time']]

    for i in range(30):

        for player in players:
            index = choice_index_randomly(arr)
            arr[index] = player

            start_time = time.time()
            frag = is_win(arr, index, connect_length)
            end_time = time.time()
            t = end_time-start_time
            ls.append([i, player, index, t])

            if frag:
                print('player', player, 'is won!!!')
                output_text(ls)
                return
    output_text(ls)

if __name__ == '__main__':
    main()

コメント

タイトルとURLをコピーしました