Get in touch
or send us a question?
CONTACT

Tăng tốc python x1000 với Numba

Numba là một trình biên dịch JIT mã nguồn mở giúp dịch một tập hợp con của mã Python và NumPy thành mã máy nhanh.

Numba dịch các hàm Python sang mã máy được tối ưu hóa trong thời gian chạy bằng cách sử dụng thư viện trình biên dịch LLVM tiêu chuẩn. Các thuật toán số do Numba biên dịch trong Python có thể đạt đến tốc độ của C hoặc FORTRAN. Bạn không cần phải thay thế trình thông dịch Python, chạy một bước biên dịch riêng biệt hoặc thậm chí cài đặt trình biên dịch C / C ++. Chỉ cần áp dụng một trong các trình trang trí Numba cho hàm Python của bạn và Numba thực hiện phần còn lại.

Chúng ta sẽ đi qua 3 feature chính của Numba là jit, njit và vectorize.

1.JIT

JIT là viết tắt của just-in-time, là 1 decorator của Numba giúp chuyển code python về dạng native code.

Xét ví dụ để thấy độ hiệu quả của nó:

import time 
import random
from numba import jit
from functools import wraps
def timer(f):
    @wraps(f)
    def func(*args, **kwargs):
        t1 = time.time()
        rs = f(*args, **kwargs)
        t2 = time.time()
        print(f"Time process {f.__name__} is {(t2-t1)*1000} ms")
        return rs
    return func

@timer
def cal(nsamples):
    acc = 0
    for i in range(nsamples):
        x = random.random()
        y = random.random()
        if (x**2 + y**2) < 1.0:
            acc +=1 
    
    return 4.0 * acc / nsamples

cal(10000)

Output: Time process cal is 2.580881118774414 ms. ( Tốn 2,5 miliseconds để xứ lý)

Khi xử dụng jit

@timer
@jit()
def cal_jit(nsamples):
    acc = 0
    for i in range(nsamples):
        x = random.random()
        y = random.random()
        if (x**2 + y**2) < 1.0:
            acc +=1 
    
    return 4.0 * acc / nsamples

print("first call")
cal_jit(10000)
print("second call")
cal_jit(10000)

Output:

first call
Time process cal_jit is 294.0945625305176 ms
second call
Time process cal_jit is 0.10800361633300781 ms

Ta nhận tại lần gọi đầu tiên tốn khá nhiều thời gian để chạy. Là do khi jit cần phải tốn thời gian để xác định các thành phần trong code thuộc dạng dữ liệu nào và cấp phát bộ nhớ cho chúng. Và sau đó thì thử hiện chỉ tốn 0.1 miliseconds

Thử xét thêm 1 ví dụ khác để hiểu thêm về jit:

@timer
def dump_func():
    out = []
    for i in range(100):
        if i % 2 ==0:
            out.append(2)
        else:
            out.append('1')

    return out

@timer
@jit
def dump_func_jit():
    out = []
    for i in range(100):
        if i % 2 ==0:
            out.append(2)
        else:
            out.append('1')

    return out

dump_func()
print("first time")
dump_func_jit()
print("second time")
dump_func_jit()

Output:

Time process dump_func is 0.009298324584960938 ms
first time

Time process dump_func_jit is 436.28907203674316 ms
second time
Time process dump_func_jit is 0.09059906005859375 ms

Ngoài ra chúng ta con nhận được warning:

During: resolving callee type: Compilation is falling back to object mode WITH looplifting enabled because Function “dump_func_jit” failed type inference due to: Invalid use of BoundFunction(list.append for list(int64)) with parameters (Literalstr). BoundFunction(list.append for list(int64))
During: typing of call at test.py (71)

Ta có thể thấy tốc độ của jit sau khi chạy lần đầu không nhưng không giảm con tăng. Lý do chính là bởi khi jit thực hiện chuyển sang native code, nó đã nhận diện được mảng out của ta sẽ là một mảng int64 nhưng sau đó lại được thêm 1 dữ liệu dạng string vào dẫn đến jit buộc phải chuyền về chế độ object mode. Chế độ này nó sẽ sinh ra native code gần giống dạng code python, và khiến quá trình tính toán chậm đi khi chương trình cố gắng flexible kiểu dữ liệu.

Nếu chúng ta sửa lại như sau:

@timer
def dump_func():
    out = []
    for i in range(100):
        if i % 2 ==0:
            out.append(2)
        else:
            out.append(1)

    return out

@timer
@jit
def dump_func_jit():
    out = []
    for i in range(100):
        if i % 2 ==0:
            out.append(2)
        else:
            out.append(1)

    return out

dump_func()
print("first time")
dump_func_jit()
print("second time")
dump_func_jit()

Output:

Time process dump_func is 0.010251998901367188 ms
first time
Time process dump_func_jit is 222.58901596069336 ms
second time
Time process dump_func_jit is 0.00286102294921875 ms

Thời gian thưc hiện đã được cải thiện lên rất nhiều. Vậy nên khi sử dụng cần chú ý định dạng kiểu dữ liệu. Và Numba cung cấp 1 feature nhằm tránh tình huống này xảy ra là njit.

2.njit

Njit là cơ chế jit với đặc điểm hạn chế code được biên dịch về chế độ object mode. Khi chúng ta thực hiện đoạn code trên thay vì revert về chế độ object mode và thông báo warning thì njt sẽ báo error và không thực hiện biên dịch

Chúng ta có thể thay njit bằng các sử dụng @jit(nopython=True) cũng có tác dụng tương tự.

3.Vectorize

Numba’s vectorize cho phép các hàm Python lấy các đối số đầu vào vô hướng được sử dụng như các hàm NumPy. Tạo một ufunc NumPy truyền thống không phải là quá trình đơn giản nhất và liên quan đến việc viết một số mã C.

Chúng ta có thể xem ví dụ dưới đây để dễ hiểu hơn:

print("############## jit #################")

@timer
def dump_func():
    out = []
    for i in range(100000):
        if i % 2 ==0:
            out.append(2)
        else:
            out.append(1)

    return out

@timer
@jit
def dump_func_jit():
    out = []
    for i in range(100000):
        if i % 2 ==0:
            out.append(2)
        else:
            out.append(1)

    return out

dump_func()
print("first time")
dump_func_jit()
print("second time")
dump_func_jit()



import numpy as np
from numba import vectorize

@timer
@vectorize
def scalar_compute(num):
    if num % 2 == 0:
        return 0
    else:
        return 1

print("############## vectorize #################")
scalar_compute(10)
scalar_compute(np.array(100000))

Output:

######## jit#######

Time process dump_func is 6.188392639160156 ms
first time
Time process dump_func_jit is 230.7758331298828 ms
second time
Time process dump_func_jit is 0.7870197296142578 ms

######## vecterize#######

Time process scalar_compute is 48.999786376953125 ms
Time process scalar_compute is 0.0021457672119140625 ms

Dựa vào output ta có thể đưa ra 2 nhận xét:

  1. Với đầu vào của vectorize chúng ta có thể đưa vào bất kỳ kiểu dữ liệu số là int hoặc int array không cố định 1 kiểu dữ liệu
  2. Ta có thể thấy tốc độ của vectorize nhanh hơn hẳn so với trường hợp sử dụng jit bên trên. Lý do là trường hợp jit bên trên sử dụng biến out là mảng không có kích thước xác định, trong khi đó đối với vectorize thì đầu ra của output đã được xác định dựa vào input nên việc cấp phát bộ nhớ lữu trữ dễ dàng hơn, việc convert native code hiệu quả hơn nên tốc đọ nhanh hơn hẳn.
  3. Nếu sử dụng dụng jit để có tốc độ như vectorize trong trường hợp trên ta có thể xử dụng numpy để cố định đầu ra của hàm. out= numpy.zeros(100000) như vậy sẽ giúp tốc độ của chúng như nhau

Cuối cùng Numba cung cấp 1 công cụ hỗ trợ việc tăng hiệu tốc độ xử lý của chương trình. Ngoài các đặc trưng trên có thể tìm hiểu thêm về cách xử lý list với Numa type list, sử dụng jit(nogil=True) để tăng hiệu suất xử lý đa luồng của Numba khiến cho các tiến trình chạy thường mất hàng tiếng chỉ còn trong vài phút. Cảm ơn vì đã đọc hết đến đây :).

Tham khảo :

[1] https://numba.pydata.org/