快轉到主要內容

幫 Celery 加上 type hint (ParamSpec)

·1222 字·3 分鐘
Denny Cheng / 月月冬瓜
作者
Denny Cheng / 月月冬瓜
獸控兼工程師兼鍵盤武術家

問題
#

Celery 作者可能有反社會人格,當你使用此 lib,所有的 type hint 都會消失。

官網提供的範例為例。

給定下列宣告

from celery import Celery

app = Celery('tasks', broker='pyamqp://guest@localhost//')

@app.task
def add(x: int, y: int) -> int:
    return x + y

當你開始使用時,會發現 .delay 根本沒被定義。

x = add.delay(1,2) # 紅線: Cannot access attribute "delay" for class "FunctionType" Attribute "delay" is unknown

這就導致當 task 多了一個 parameter,return type 變動,即使使用者忘記改 caller,也不會被提醒。

ParamSpec
#

decorator 介紹
#

雖然不知道 Celery 作者是多低能才會寫出這種智障設計,但還是有一些自救手段。

首先先閱讀 Python typing lib 中的 ParamSpec

這個功能是不定參數的好夥伴,可以幫一些底層為 *args, **kwargs 但其實有固定參數組合的東西加上 type hint。
以一個最簡單的例子

def deco(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper 

@deco
def add1(a: int, b: int):
    return a + b

def _add2(a: int, b: int):
    return a + b

add2 = deco(_add2)

可以看一下當錯誤使用時會發生什麼事

add1(1,2,3) # 紅線: Expected 2 positional arguments
add2(1,2,3) # 無紅線,pyright 無法分析出 add2 其實只有兩個參數

add1add2 其實語意是相同的,但 add2 因為沒使用 decorator 語法,所以 pyright 無法進行推導。
但這並不是理所當然的,如果標記正確,不論把 deco 當 decorator 使用,還是當作一般 function 使用,都可以得到正確結果。
只能在 decorator 得到正確結果是因為 pyright 偷吃步幫你省掉一些工了。

標記正確型別
#

該如何標記正確的型別?
首先 def deco(func): 本身宣告就有問題,func 應該只允許 Callable,不允許其他型態,否則使用 func(*args, **kwargs) 就會報錯。

但該如何標記此 CallableCallable 的兩個參數分別代表 parameter 和 回傳值,舉例來說

  1. 如果裝飾 def add(a: int, b: int) -> int: ... 的話,那型別為 Callable[[int, int], int]
  2. 如果裝飾 def concat(a: str, b: str) -> str: ... 的話,那型別為 Callable[[str, str], str]

但此處,我們希望他可以裝飾任何東西,因為 decorator 很常見的狀況是作為 logger、exception handler、retry 存在。只會在本身做一些額外事項,接下來就把參數原封不動的送進原 function 中。因此我們需要一個更靈活的方式來標記這類情況,根據裝飾的類型不同,輸出的參數也不同。這時候就會用到剛剛提到的 ParamSpec

from typing import Callable, ParamSpec, TypeVar


P = ParamSpec("P")
T = TypeVar("T")

def deco(func: Callable[P, T]) -> Callable[P, T]:
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        return func(*args, **kwargs)

    return wrapper

首先觀看 def deco(func: Callable[P, T]) -> Callable[P, T]:
利用 ParamSpec 和 Typevar 來代指參數和回傳值,這句話的意思即是說輸入和輸出皆為 Callble,且這兩個 Callable 的參數跟回傳值皆相同。

接著觀看 def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:,這個 function 的解讀為

  1. *args**kwargs 要跟傳進來的 func 可接受的值一模一樣。
  2. return type 也要是傳進來的 func 的 return type。
  3. 承1,2,就可以發現 def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: 實際上的型別就是 Callable[P, T]

當作完此型態標注之後,再使用剛剛的的例子

add1(1,2,3) # 紅線: Expected 2 positional arguments
add2(1,2,3) # 紅線: Expected 2 positional arguments

就會發現兩個 function 都能正確解讀了。

標記 Celery
#

學會使用 ParamSpec 後,對 celery 加上 type hint 就簡單多了。

T = TypeVar("T", covariant=True)
P = ParamSpec("P")

class CeleryTaskHint(Protocol[P, T]):
    def delay(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...


def hint(func: Callable[P, T]) -> CeleryTaskHint[P, T]:
    return func  # type: ignore

解讀:傳進一個 Callable,傳出一個 CeleryTaskHint 的 protocol,他可以作為 function 或是 .delay 使用,這兩個用法的 args 和 kwargs 的參數接跟傳進來的 function 一樣。

使用方法

@hint
@app.task
def add(x: int, y: int) -> int:
    return x + y

此時再使用就能看到正確的錯誤訊息了

x = add.delay(1, "2") # 紅線:Argument of type "Literal['2']" cannot be assigned to parameter "b" of type "int" "Literal['2']" is not assignable to "int" 
x = add.delay(3, 4, 5) # 紅線:Expected 2 positional arguments