PythonのJITがすごい

2024年1月15日

linux + python + opencvでロゴ検出類似度であれこれしようとしてるときに
pythonのNumbaのjitがすごい!みたいな記事を見かけたので試してみました。

とりあえず試してみる

プログラミングは専門外なので詳しくないですが

こんなやつにモンテカルロ法を試してみることにしました。

from numba import jit ,njit
from numba import prange

@jit()
def function4():
    hoge
    return

defの前に@jitをつけるだけ。

結果(N=10**8)

 

エントリー時間
C2.45秒
func128.00秒
func2(1+jit)3.09秒
func332.8秒
func4(3+jit)1.41秒
func5(3+jit+prange)0.66秒
func6(numpy)2.67秒

@jitつけるだけでお手軽で速すぎなのでは。
pythonのループは遅い(とくにfor)ってきいてたけどデフォは確かに遅かった。
numpyはよく分かってないのでこれでいいのかは不明。

間違いがあれば教えてください。

 

使ったもの1

#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <math.h>

int main(void){

  long int i;
  long int count = 0;
  long int max = pow(10,8);
  double x,y,pi;
  clock_t start,end;

  srand(time(NULL));
  start = clock();
 
  for(i=0;i<max;i++){
    x = (double)rand()/RAND_MAX;
    y = (double)rand()/RAND_MAX;
    if(y<1/(1+x*x))
      count++;
  }
  pi = (double)count / max * 4;
  end = clock();
  printf("%f\n", pi);
  printf("%.2f秒かかりました\n",(double)(end-start)/CLOCKS_PER_SEC);
  return 0;
}

 

使ったもの2

#!/usr/bin/python3
# -*- coding: utf-8 -*-
import time
import random
import numpy as np

from numba import jit ,njit
from numba import prange
N=10**8

def function1():
	A=[1 if 1/(1+(random.random())**2) > random.random() else 0 for i in range(N) ]
	ans=4*A.count(1)/N
	return

@jit
def function2():
	A=[1 if 1/(1+(random.random())**2) > random.random() else 0 for i in range(N) ]
	ans=4*A.count(1)/N
	return

def g(x):
    return 1.0/(1.0+x**2)

def function3():
	count = 0
	for i in range(N):
		x ,y= random.random(),random.random()
		if y < g (x):
			count +=1
	ans=4*count/N
	print(ans)
	return

@jit('f8(f8)')
def f(x):
    return 1.0/(1.0+x**2)

@jit()
def function4():
	count = 0
	for i in range(N):
		x ,y= random.random(),random.random()
		if y < f(x):
			count +=1
	ans=4*count/N
	print(ans)
	return

@jit(nopython=True, parallel=True)
def function5():
	count = 0
	for i in prange(N):
		x ,y= random.random(),random.random()
		if y < f(x):
			count +=1
	ans=4*count/N
	print(ans)
	return

def function6():
	x, y = np.random.rand(N), np.random.rand(N)
	c = (y<1.0/(1.0+x**2) ).sum()
	ans=4*c/N
	print(ans)
	return

python

Posted by neff