Numba là một trình biên dịch JIT mã nguồn mở giúp dịch một tập hợp con của mã Python và NumPy thành mã máy nhanh.
Numba dịch các hàm Python sang mã máy được tối ưu hóa trong thời gian chạy bằng cách sử dụng thư viện trình biên dịch LLVM tiêu chuẩn. Các thuật toán số do Numba biên dịch trong Python có thể đạt đến tốc độ của C hoặc FORTRAN. Bạn không cần phải thay thế trình thông dịch Python, chạy một bước biên dịch riêng biệt hoặc thậm chí cài đặt trình biên dịch C / C ++. Chỉ cần áp dụng một trong các trình trang trí Numba cho hàm Python của bạn và Numba thực hiện phần còn lại.
Chúng ta sẽ đi qua 3 feature chính của Numba là jit, njit và vectorize.
1.JIT
JIT là viết tắt của just-in-time, là 1 decorator của Numba giúp chuyển code python về dạng native code.
Xét ví dụ để thấy độ hiệu quả của nó:
import time
import random
from numba import jit
from functools import wraps
def timer(f):
@wraps(f)
def func(*args, **kwargs):
t1 = time.time()
rs = f(*args, **kwargs)
t2 = time.time()
print(f"Time process {f.__name__} is {(t2-t1)*1000} ms")
return rs
return func
@timer
def cal(nsamples):
acc = 0
for i in range(nsamples):
x = random.random()
y = random.random()
if (x**2 + y**2) < 1.0:
acc +=1
return 4.0 * acc / nsamples
cal(10000)
Output: Time process cal is 2.580881118774414 ms. ( Tốn 2,5 miliseconds để xứ lý)
Khi xử dụng jit
@timer
@jit()
def cal_jit(nsamples):
acc = 0
for i in range(nsamples):
x = random.random()
y = random.random()
if (x**2 + y**2) < 1.0:
acc +=1
return 4.0 * acc / nsamples
print("first call")
cal_jit(10000)
print("second call")
cal_jit(10000)
Output:
first call
Time process cal_jit is 294.0945625305176 ms
second call
Time process cal_jit is 0.10800361633300781 ms
Ta nhận tại lần gọi đầu tiên tốn khá nhiều thời gian để chạy. Là do khi jit cần phải tốn thời gian để xác định các thành phần trong code thuộc dạng dữ liệu nào và cấp phát bộ nhớ cho chúng. Và sau đó thì thử hiện chỉ tốn 0.1 miliseconds
Thử xét thêm 1 ví dụ khác để hiểu thêm về jit:
@timer
def dump_func():
out = []
for i in range(100):
if i % 2 ==0:
out.append(2)
else:
out.append('1')
return out
@timer
@jit
def dump_func_jit():
out = []
for i in range(100):
if i % 2 ==0:
out.append(2)
else:
out.append('1')
return out
dump_func()
print("first time")
dump_func_jit()
print("second time")
dump_func_jit()
Output:
Time process dump_func is 0.009298324584960938 ms
first time
Time process dump_func_jit is 436.28907203674316 ms
second time
Time process dump_func_jit is 0.09059906005859375 ms
Ngoài ra chúng ta con nhận được warning:
During: resolving callee type: Compilation is falling back to object mode WITH looplifting enabled because Function “dump_func_jit” failed type inference due to: Invalid use of BoundFunction(list.append for list(int64)) with parameters (Literalstr). BoundFunction(list.append for list(int64))
During: typing of call at test.py (71)
Ta có thể thấy tốc độ của jit sau khi chạy lần đầu không nhưng không giảm con tăng. Lý do chính là bởi khi jit thực hiện chuyển sang native code, nó đã nhận diện được mảng out của ta sẽ là một mảng int64 nhưng sau đó lại được thêm 1 dữ liệu dạng string vào dẫn đến jit buộc phải chuyền về chế độ object mode. Chế độ này nó sẽ sinh ra native code gần giống dạng code python, và khiến quá trình tính toán chậm đi khi chương trình cố gắng flexible kiểu dữ liệu.
Nếu chúng ta sửa lại như sau:
@timer
def dump_func():
out = []
for i in range(100):
if i % 2 ==0:
out.append(2)
else:
out.append(1)
return out
@timer
@jit
def dump_func_jit():
out = []
for i in range(100):
if i % 2 ==0:
out.append(2)
else:
out.append(1)
return out
dump_func()
print("first time")
dump_func_jit()
print("second time")
dump_func_jit()
Output:
Time process dump_func is 0.010251998901367188 ms
first time
Time process dump_func_jit is 222.58901596069336 ms
second time
Time process dump_func_jit is 0.00286102294921875 ms
Thời gian thưc hiện đã được cải thiện lên rất nhiều. Vậy nên khi sử dụng cần chú ý định dạng kiểu dữ liệu. Và Numba cung cấp 1 feature nhằm tránh tình huống này xảy ra là njit.
2.njit
Njit là cơ chế jit với đặc điểm hạn chế code được biên dịch về chế độ object mode. Khi chúng ta thực hiện đoạn code trên thay vì revert về chế độ object mode và thông báo warning thì njt sẽ báo error và không thực hiện biên dịch
Chúng ta có thể thay njit bằng các sử dụng @jit(nopython=True) cũng có tác dụng tương tự.
3.Vectorize
Numba’s vectorize cho phép các hàm Python lấy các đối số đầu vào vô hướng được sử dụng như các hàm NumPy. Tạo một ufunc NumPy truyền thống không phải là quá trình đơn giản nhất và liên quan đến việc viết một số mã C.
Chúng ta có thể xem ví dụ dưới đây để dễ hiểu hơn:
print("############## jit #################")
@timer
def dump_func():
out = []
for i in range(100000):
if i % 2 ==0:
out.append(2)
else:
out.append(1)
return out
@timer
@jit
def dump_func_jit():
out = []
for i in range(100000):
if i % 2 ==0:
out.append(2)
else:
out.append(1)
return out
dump_func()
print("first time")
dump_func_jit()
print("second time")
dump_func_jit()
import numpy as np
from numba import vectorize
@timer
@vectorize
def scalar_compute(num):
if num % 2 == 0:
return 0
else:
return 1
print("############## vectorize #################")
scalar_compute(10)
scalar_compute(np.array(100000))
Output:
Time process dump_func is 6.188392639160156 ms
first time
Time process dump_func_jit is 230.7758331298828 ms
second time
Time process dump_func_jit is 0.7870197296142578 ms
Time process scalar_compute is 48.999786376953125 ms
Time process scalar_compute is 0.0021457672119140625 ms
Dựa vào output ta có thể đưa ra 2 nhận xét:
Cuối cùng Numba cung cấp 1 công cụ hỗ trợ việc tăng hiệu tốc độ xử lý của chương trình. Ngoài các đặc trưng trên có thể tìm hiểu thêm về cách xử lý list với Numa type list, sử dụng jit(nogil=True) để tăng hiệu suất xử lý đa luồng của Numba khiến cho các tiến trình chạy thường mất hàng tiếng chỉ còn trong vài phút. Cảm ơn vì đã đọc hết đến đây :).
Tham khảo :
[1] https://numba.pydata.org/
You need to login in order to like this post: click here
YOU MIGHT ALSO LIKE