Shape typing numpy with pyright and variadic generics
February 27th, 2023 · 7 min read
When doing any sort of tensor/array computation in python (via numpy, pytorch, jax, or other), it's common to run into shape errors like the one below:
import numpy as np

size1 = (2,3)
size2 = (4,3)

M1 = np.random.random(size=size1)
M2 = np.random.random(size=size2)

try:
    print(np.dot(M1,M2))
except Exception as e:
    print(e)

shapes (2,3) and (4,3) not aligned: 3 (dim 1) != 4 (dim 0)
And most of the time, these kinds of errors boil down to something like accidentally forgetting to do a reshape or transpose. Adding the missing transpose fixes the error:
import numpy as np

size1 = (2,3)
size2 = (4,3)

M1 = np.random.random(size=size1)
M2 = np.random.random(size=size2).T

try:
    print(np.dot(M1,M2))
except Exception as e:
    print(e)

[[0.68034751 0.49931699 0.71019458 0.22819836]
 [0.56876626 0.5991015 1.19438961 0.3778173 ]]
And while this is a mild case, shape bugs like these become more frequent as
operations grow more complex and as more dimensions are involved.
Here's a slightly more complex example of a Linear implementation in numpy with a subtle shape bug.
def Linear(A, x, b):
    """
    Takes matrix A (m x n) times a vector x (n x 1) and
    adds a bias. The resulting ndarray is then ravelled
    into a vector of size (m).
    """
    Ax = np.dot(A, x)
    Axb = np.add(Ax, b)
    return np.ravel(Axb)

A = np.random.random(size=(4,4))
x = np.random.random(size=(4,1))
b = np.random.random(size=(4))

result = Linear(A, x, b)
print(result)
print(result.shape)

[1.56441647 1.47650364 1.57303598 1.14131074 1.18493577 1.09702295
 1.19355528 0.76183004 1.57644585 1.48853302 1.58506536 1.15334012
 1.18260885 1.09469603 1.19122836 0.75950312]
(16,)
The docstring of Linear clearly says the result should be size m (or 4). But why then did we end up with a vector of size 16? If we dig into each function we will eventually find that our problem is in how numpy handles ndarrays of different shapes.
If we break down Linear, after np.dot we have an ndarray of shape (4,1), to which we np.add a vector of shape (4,). And here lies our bug. We might naturally think that np.add will do this addition element wise, but instead we fell into an array broadcasting trap. Array broadcasting is the set of rules numpy uses to determine how to do arithmetic on differently shaped ndarrays. So instead of doing our computation element wise, numpy stretches both operands to a common shape before adding them, resulting in a (4,4) matrix, which subsequently gets "raveled" into a size 16 vector.
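To make the broadcasting trap concrete, here is a rough pure-python sketch of how the broadcast result shape is worked out. The broadcast_shape helper is a hypothetical name made up for illustration (it is not a numpy function), and it only models the shape rule: align the shapes on their trailing dimensions, and treat two sizes as compatible when they are equal or one of them is 1.

```python
from itertools import zip_longest
from typing import Tuple


def broadcast_shape(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
    """Hypothetical helper: result shape of broadcasting two shapes."""
    result = []
    # Align the shapes on their trailing dimensions; missing leading
    # dimensions are treated as size 1.
    for dim_a, dim_b in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if dim_a == dim_b or dim_b == 1:
            result.append(dim_a)
        elif dim_a == 1:
            result.append(dim_b)
        else:
            raise ValueError(f"shapes {a} and {b} cannot be broadcast together")
    return tuple(reversed(result))


print(broadcast_shape((4, 1), (4,)))    # the bias shape from the buggy example
print(broadcast_shape((4, 1), (4, 1)))  # a bias with an explicit column dimension
```

The first call gives (4, 4), which is exactly the unexpected matrix that gets raveled into 16 elements, while the second stays at (4, 1).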
The fix is easy: we just need to initialize our b variable with shape (4,1) so numpy will interpret the np.add as an element-wise addition.
def Linear(A, x, b):
    """
    Takes matrix A (m x n) times a vector x (n x 1) and
    adds a bias. The resulting ndarray is then ravelled
    into a vector of size (m).
    """
    Ax = np.dot(A, x)
    Axb = np.add(Ax, b)
    return np.ravel(Axb)

A = np.random.random(size=(4,4))
x = np.random.random(size=(4,1))
b = np.random.random(size=(4,1))

result = Linear(A, x, b)
print(result)
print(result.shape)

[1.97154938 1.84206823 2.94043945 1.7510802 ]
(4,)
We've solved the problem, but how can we be smarter to prevent this error from
happening again?
Existing ways to stop shape bugs
The simplest way we can try to stop this shape bug is with good docs. Ideally
we should always have good docs, but we can make it a point to include what
the shape expectations are like so:
def Linear(A, x, b):
    """
    Args:
        A: ndarray of shape (M x N)
        x: ndarray of shape (N x 1)
        b: ndarray of shape (M x 1)

    Returns:
        Linear output ndarray of shape (M)
    """
    Ax = np.dot(A, x) # Shape (M x 1)
    Axb = np.add(Ax, b) # (M x 1) + (M x 1)
    return np.ravel(Axb) # Shape (M)
Now while informative, nothing is preventing us from encountering the same bug again. The only benefit this gives us is making the debugging process a bit easier.
We can do better.
Another approach, in addition to good docs, that's more of a preventative action is to use assertions. By sprinkling assert statements with informative error messages throughout Linear, we can "fail early" and start debugging like so:
def Linear(A, x, b):
    """
    Args:
        A: ndarray of shape (M x N)
        x: ndarray of shape (N x 1)
        b: ndarray of shape (M x 1)

    Returns:
        Linear output ndarray of shape (M)
    """
    assert len(A.shape) == 2, f"A must be of dim 2, not {len(A.shape)}"
    Am, An = A.shape

    assert x.shape == (An, 1), f"X must be shape ({An}, 1) to do dot"
    Ax = np.dot(A, x) # Shape (M x 1)

    assert b.shape == (Am, 1), f"Bias term must be shape ({Am}, 1)"
    result = np.add(Ax, b) # (M x 1) + (M x 1)

    ravel_result = np.ravel(result)
    assert ravel_result.shape == (Am,), f"Uh oh, ravel result is shape {ravel_result.shape} and not {(Am,)}"
    return ravel_result
At every step of this function we do an assert to make sure all the ndarray shapes are what we expect.
As a result Linear is a bit "safer". But compared to what we had originally, this approach is much less readable. We also inherit some of the baggage that comes with runtime error checking like:
Incomplete checking: Have we checked all expected shape failure modes?
Slow debugging cycles: How many refactor->run cycles will we have to do to pass the checks?
Additional testing: Do we have to update our tests to cover our runtime error checks?
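One way to claw back a bit of readability is to pull the repeated check into a small helper. The sketch below is only an illustration: expect_shape is a hypothetical name, and it works on plain shape tuples so that it runs without numpy (at a real call site we would pass something like b.shape). It is still a runtime check, so the baggage listed above still applies.

```python
from typing import Tuple


def expect_shape(name: str, actual: Tuple[int, ...], expected: Tuple[int, ...]) -> None:
    """Hypothetical helper: raise a descriptive error on a shape mismatch."""
    if actual != expected:
        raise ValueError(f"{name} must have shape {expected}, got {actual}")


expect_shape("b", (4, 1), (4, 1))  # passes silently

try:
    expect_shape("b", (4,), (4, 1))  # the bias shape from the buggy example
except ValueError as e:
    print(e)
```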
Overall runtime error checking is not a bad thing. In most cases it's very
necessary! But when it comes to shape errors, we can leverage an additional
approach, static type checking.
Even though python is a dynamically typed language, python 3.5 introduced the typing module to enable static type checkers to validate type hinted python code. (See this video for more details)
Over time many third party libraries (like numpy) have started to type hint their codebases which we can use to our benefit.
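As a quick reminder of what "static" means here, the toy function below (double is a made-up example, not from any library) shows that python itself never enforces type hints. The mistyped call runs without complaint, and it is only a type checker reading the annotations that would flag it.

```python
def double(x: int) -> int:
    """Annotated to take and return an int."""
    return x * 2


# Python does not enforce annotations at runtime, so this call runs and
# repeats the string, even though a static type checker would reject it.
print(double("ab"))
```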
In order to help us prevent shape errors, let's see what typing capabilities exist in numpy.
dtype typing numpy arrays
As of writing this post, numpy==v1.24.2 only supports typing on an ndarray's dtype (uint8, float64, etc.).
Using numpy's existing type hinting tooling, here's how we would include dtype type information to our Linear example (note: there is an intentional type error)
from typing import TypeVar

import numpy as np
from numpy.typing import NDArray

GenericType = TypeVar("GenericType", bound=np.generic)


def Linear(
    A: NDArray[GenericType],
    x: NDArray[GenericType],
    b: NDArray[GenericType],
) -> NDArray[GenericType]:
    """
    Args:
        A: ndarray of shape (M x N)
        x: ndarray of shape (N x 1)
        b: ndarray of shape (M x 1)

    Returns:
        Linear output ndarray of shape (M)
    """
    assert len(A.shape) == 2, f"A must be of dim 2, not {len(A.shape)}"
    Am, An = A.shape

    assert x.shape == (An, 1), f"X must be shape ({An}, 1) to do dot"
    Ax: NDArray[GenericType] = np.dot(A, x) # Shape (M x 1)

    assert b.shape == (Am, 1), f"Bias term must be shape ({Am}, 1)"
    result: NDArray[GenericType] = np.add(Ax, b) # (M x 1) + (M x 1)

    ravel_result: NDArray[GenericType] = np.ravel(result)
    assert ravel_result.shape == (Am,), f"Uh oh, ravel result is shape {ravel_result.shape} and not {(Am,)}"
    return ravel_result


A: NDArray[np.float64] = np.random.standard_normal(size=(10, 10))
x: NDArray[np.float64] = np.random.standard_normal(size=(10, 1))
b: NDArray[np.float32] = np.random.standard_normal(size=(10, 1))
y: NDArray[np.float64] = Linear(A, x, b)
print(y)
print(y.dtype)

[-0.53752298 3.67833386 -1.43092158 -2.58647295 2.44053318 -1.9393581
 -0.23397058 0.79320484 2.16039462 -0.10612777]
float64
Even though this code is "runnable" and doesn't produce an error, a type checker like pyright tells us a different story.
pyright linear_bad_typing.py

No configuration file found.
No pyproject.toml file found.
stubPath /mnt/typings is not a valid directory.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/linear_bad_typing.py
  /mnt/linear_bad_typing.py:40:26 - error: Expression of type "ndarray[Any, dtype[float64]]" cannot be assigned to declared type "NDArray[float32]"
    "ndarray[Any, dtype[float64]]" is incompatible with "NDArray[float32]"
    TypeVar "_DType_co@ndarray" is covariant
    "dtype[float64]" is incompatible with "dtype[float32]"
    TypeVar "_DTypeScalar_co@dtype" is covariant
    "float64" is incompatible with "float32" (reportGeneralTypeIssues)
  /mnt/linear_bad_typing.py:41:39 - error: Argument of type "NDArray[float32]" cannot be assigned to parameter "b" of type "NDArray[GenericType@Linear]" in function "Linear"
    "NDArray[float32]" is incompatible with "NDArray[float64]"
    TypeVar "_DType_co@ndarray" is covariant
    "dtype[float32]" is incompatible with "dtype[float64]"
    TypeVar "_DTypeScalar_co@dtype" is covariant
    "float32" is incompatible with "float64" (reportGeneralTypeIssues)
2 errors, 0 warnings, 0 informations
Completed in 0.606sec
pyright has noticed that when we create our b variable, we gave it a dtype type that is incompatible with what np.random.standard_normal returns.
Now we know to adjust the type hint of b to be in line with the dtype that np.random.standard_normal produces (NDArray[np.float64]).
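For reference, here is a minimal sketch of what the corrected annotation could look like. The second half is my own assumption about what one might do if float32 really is the goal (it is not part of the example above): convert explicitly with astype so the annotation and the runtime dtype agree.

```python
import numpy as np
from numpy.typing import NDArray

# standard_normal produces float64 values, so float64 is the matching hint.
b: NDArray[np.float64] = np.random.standard_normal(size=(10, 1))
print(b.dtype)

# Assumption for illustration: if float32 is what we actually want, convert
# explicitly so that the annotation and the runtime dtype agree.
b32: NDArray[np.float32] = b.astype(np.float32)
print(b32.dtype)
```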
Shape typing numpy arrays
While dtype typing is great, it's not the most useful for preventing shape errors (like from our original example).
Ideally, in addition to a dtype type, we could also include information about an ndarray's shape to do shape typing.
Shape typing is a technique used to annotate information about the dimensionality and size of an array. In the context of numpy and the python type hinting system, we can use shape typing to catch shape errors before runtime.
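To get a feel for what "shape as a type" can mean in python, here is a small sketch that encodes a 2x3 shape with Tuple and Literal from the standard typing module. Both Shape2x3 and shape_from_type are hypothetical names used only for this illustration; the helper just reads the sizes back out to show that the type really carries them.

```python
from typing import Literal, Tuple, get_args

# A shape expressed purely as a type: a 2x3 matrix.
Shape2x3 = Tuple[Literal[2], Literal[3]]


def shape_from_type(shape_type) -> Tuple[int, ...]:
    """Hypothetical helper: read the dimension sizes back out of a shape type."""
    return tuple(get_args(dim)[0] for dim in get_args(shape_type))


print(shape_from_type(Shape2x3))  # (2, 3)
```

A type checker never needs a helper like this, since it compares the Literal sizes directly, which is what lets it reject a 3x3 value where a 2x3 one is declared.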
For more information about shape typing, check out this google doc on a shape typing syntax proposal by Matthew Rahtz, Jörg Bornschein, Vlad Mikulik, Tim Harley, Matthew Willson, Dimitrios Vytiniotis, Sergei Lebedev, Adam Paszke.
As we've seen, numpy's NDArray currently only supports dtype typing and doesn't have any of this kind of shape typing ability. But why is that? If we dig into the definition of the NDArray type:
ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)

if TYPE_CHECKING or sys.version_info >= (3, 9):
    _DType = np.dtype[ScalarType]
    NDArray = np.ndarray[Any, np.dtype[ScalarType]]
else:
    _DType = _GenericAlias(np.dtype, (ScalarType,))
    NDArray = _GenericAlias(np.ndarray, (Any, _DType))
And follow the definition of np.ndarray ...

class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
We can see that it looks like numpy already has a shape type parameter (_ShapeType)! But unfortunately if we look at the definition for this ...
# TODO: Set the `bound` to something more suitable once we
# have proper shape support
_ShapeType = TypeVar("_ShapeType", bound=Any)
_ShapeType2 = TypeVar("_ShapeType2", bound=Any)
😭 Looks like we're stuck with Any which doesn't add any useful shape information on our types.
Luckily for us, we don't have to wait for shape support in numpy. PEP 646 lays the foundation for shape typing and has already been accepted into python==3.11! And it's supported by pyright! Theoretically these two things give us most of the ingredients to do basic shape typing.
Now this blog post isn't about the details of PEP 646 or variadic
generics. Understanding PEP 646 will help, but it's not needed to understand
the rest of this post.
In order to add rudimentary shape typing to numpy we can simply change the Any type in the NDArray type definition to an unpacked variadic generic like so:
ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
Shape = TypeVarTuple("Shape")

if TYPE_CHECKING or sys.version_info >= (3, 9):
    _DType = np.dtype[ScalarType]
    NDArray = np.ndarray[*Shape, np.dtype[ScalarType]]
else:
    _DType = _GenericAlias(np.dtype, (ScalarType,))
    NDArray = _GenericAlias(np.ndarray, (Any, _DType))
Doing so allows us to fill in a Tuple based type (indicating shape) in an NDArray alongside a dtype type. And shape typing with Tuples enables us to define function overloads which describe to a type checker the possible ways a function can change the shape of an NDArray.
Let's look at an example of using these concepts to type a wrapper function for np.random.standard_normal from our Linear example with an intentional type error:
import numpy as np
from numpy.typing import NDArray
from typing import Tuple, TypeVar, Literal

# Generic dimension sizes types
T1 = TypeVar("T1", bound=int)
T2 = TypeVar("T2", bound=int)
T3 = TypeVar("T3", bound=int)

# Dimension types represented as tuples
Shape = Tuple
Shape1D = Shape[T1]
Shape2D = Shape[T1, T2]
Shape3D = Shape[T1, T2, T3]
ShapeND = Shape[T1, ...]
ShapeNDType = TypeVar("ShapeNDType", bound=ShapeND)

def rand_normal_matrix(shape: ShapeNDType) -> NDArray[ShapeNDType, np.float64]:
    """Return a random ND normal matrix."""
    return np.random.standard_normal(size=shape)

# Yay correctly typed 2x2x2 cube!
LENGTH = Literal[2]
cube: NDArray[Shape3D[LENGTH, LENGTH, LENGTH], np.float64] = rand_normal_matrix((2,2,2))
print(cube)

SIDE = Literal[4]

# Uh oh the shapes won't match!
square: NDArray[Shape2D[SIDE, SIDE], np.float64] = rand_normal_matrix((3,3))
print(square)
Notice here there are no assert statements. And instead of several comments about shape, we indicate shape in the type hint.
Now while this code is "runnable", pyright will tell us something else:
py -m pyright bad_shape_typing.py --lib

No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/bad_shape_typing.py
  /mnt/bad_shape_typing.py:30:71 - error: Argument of type "tuple[Literal[3], Literal[3]]" cannot be assigned to parameter "shape" of type "ShapeNDType@rand_normal_matrix" in function "rand_normal_matrix"
    Type "Shape2D[SIDE, SIDE]" cannot be assigned to type "tuple[Literal[3], Literal[3]]" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
Completed in 0.528sec
pyright is telling us we've incorrectly typed square: its declared 4x4 type is incompatible with the 3x3 shape we passed in. Now we know we need to go back and fix the type to what a type checker should expect.
Huzzah shape typing!!
Moar numpy shape typing!
Now that we have shape typed one function, let's step it up a notch. Let's try typing each numpy function in our Linear example to include shape types. We've already typed np.random.standard_normal, so next let's do np.dot.
If we look at the docs for np.dot there are 5 cases it supports:
Both arguments are 1D arrays
Both arguments are 2D arrays (resulting in a matmul)
Either argument is a scalar
One argument is an ND array and the other is a 1D array
One argument is an ND array and the other is an MD array
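Before typing these cases, it can help to write down what each one does to the shapes. The sketch below is a plain-python summary of the shape rules as I read them from the np.dot docs; dot_shape is a hypothetical helper (not a numpy function) that represents a scalar as the empty shape ().

```python
from typing import Tuple

ShapeT = Tuple[int, ...]


def dot_shape(a: ShapeT, b: ShapeT) -> ShapeT:
    """Hypothetical helper: result shape of np.dot for operand shapes a and b.

    A scalar is represented by the empty shape ().
    """
    if a == () or b == ():       # either argument is a scalar
        return a + b
    if len(b) == 1:              # 1D or ND array with a 1D array
        contracted, rest = b[0], ()
    else:                        # array with an MD array (M >= 2)
        contracted, rest = b[-2], b[:-2] + b[-1:]
    if a[-1] != contracted:
        raise ValueError(f"shapes {a} and {b} not aligned")
    return a[:-1] + rest


print(dot_shape((3,), (3,)))      # two 1D arrays give a scalar: ()
print(dot_shape((2, 3), (3, 2)))  # matmul: (2, 2)
print(dot_shape((4, 4), (4, 1)))  # the Ax step of Linear: (4, 1)
```

The last rule is the awkward one for typing, since the output shape is stitched together from pieces of both input shapes.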
We can implement these cases as follows
ShapeVarGen = TypeVarTuple("ShapeVarGen")

@overload
def dot(x1: NDArray[Shape1D[T1], GenericDType], x2: NDArray[Shape1D[T1], GenericDType], /) -> GenericDType:
    ...


@overload
def dot(
    x1: NDArray[Shape[T1, *ShapeVarGen], GenericDType], x2: NDArray[Shape1D[T1], GenericDType], /
) -> NDArray[Shape[*ShapeVarGen], GenericDType]:
    ...


@overload
def dot(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape2D[T2, T3], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T3], GenericDType]:
    ...


@overload
def dot(x1: GenericDType, x2: GenericDType, /) -> GenericDType:
    ...


def dot(x1, x2):
    return np.dot(x1, x2)
The only case we can't implement is an ND dimensional array with an MD dimensional array. Ideally we would try implementing it like so:
ShapeVarGen1 = TypeVarTuple("ShapeVarGen1")
ShapeVarGen2 = TypeVarTuple("ShapeVarGen2")

@overload
def dot(
    x1: NDArray[Shape[*ShapeVarGen1, T1], GenericDType], x2: NDArray[Shape[*ShapeVarGen2, T1, T2], GenericDType], /
) -> NDArray[Shape[*ShapeVarGen1, *ShapeVarGen2], GenericDType]:
    ...
But currently using multiple type variable tuples is not allowed. If you know of another way to cover this case let me know! Luckily our Linear use case only uses scalars, vectors, and matrices, which are covered by our four overloads.
Here's how we would use these dot overloads to do the dot product between a 2x3 matrix and a 3x2 matrix with type hints:
import numpy as np
from numpy.typing import NDArray
from numpy_shape_typing.dot import dot
from numpy_shape_typing.types import ShapeNDType, Shape2D
from numpy_shape_typing.rand import rand_normal_matrix

from typing import Literal

ROWS = Literal[2]
COLS = Literal[3]
A: NDArray[Shape2D[ROWS, COLS], np.float64] = rand_normal_matrix((2,3))
B: NDArray[Shape2D[COLS, ROWS], np.float64] = rand_normal_matrix((3,2))
C: NDArray[Shape2D[ROWS, ROWS], np.float64] = dot(A, B)
print(C)
And if we check with pyright:

py -m pyright good_dot.py --lib

No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
0 errors, 0 warnings, 0 informations
Completed in 0.917sec
Everything looks good as it should!
And if we change the types to invalid matrix shapes:
import numpy as np
from numpy.typing import NDArray
from numpy_shape_typing.dot import dot
from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.types import ShapeNDType, Shape2D

from typing import Literal

ROWS = Literal[2]
COLS = Literal[3]
SLICES = Literal[4]

# uh oh based on these types we can't do a valid dot product!
A: NDArray[Shape2D[ROWS, COLS], np.float64] = rand_normal_matrix((2,3))
B: NDArray[Shape2D[SLICES, COLS], np.float64] = rand_normal_matrix((4,3))
C: NDArray[Shape2D[ROWS, COLS], np.float64] = dot(A, B)
print(C)
And if we check with pyright:

py -m pyright ./bad_dot.py --lib

No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/bad_dot.py
  /mnt/bad_dot.py:16:54 - error: Argument of type "NDArray[Shape2D[SLICES, COLS], float64]" cannot be assigned to parameter "x2" of type "GenericDType@dot" in function "dot"
    Type "NDArray[Shape2D[ROWS, COLS], float64]" cannot be assigned to type "NDArray[Shape2D[SLICES, COLS], float64]" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
Completed in 0.915sec
pyright lets us know that the shapes in our types are incompatible, based on the np.dot type overloads we've specified.
Even moar numpy shape typing!
The next function we are going to type is np.add. The numpy docs only show two cases.
Two ND array arguments of the same shape are added element wise
Two ND array arguments that are not the same shape must be broadcastable to a common shape
Covering the first case is easy, but the second case is much harder as we would have to come up with a scheme to cover numpy's array broadcasting system. Currently python==3.11's typing doesn't have a generic way to cover all the broadcasting rules. (If you know of a way let me know!)
However if we scope down the second case to only two dimensions, we can cover all the array broadcasting rules with a few overloads:
from typing import overload

import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.types import ONE, T1, T2, GenericDType, Shape1D, Shape2D, ShapeVarGen


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape1D[T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape1D[T2], GenericDType],
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape1D[ONE], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape1D[ONE], GenericDType],
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape2D[T1, ONE], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape2D[ONE, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, ONE], GenericDType],
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[ONE, T2], GenericDType],
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: GenericDType,
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: GenericDType,
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[*ShapeVarGen, GenericDType],
    x2: NDArray[*ShapeVarGen, GenericDType],
    /,
) -> NDArray[*ShapeVarGen, GenericDType]:
    ...


def add(x1, x2):
    return np.add(x1, x2)
Using these overloads, here is how we would catch unexpected array broadcasts (similar to the one from our original Linear example).
from typing import Literal

import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.add import add
from numpy_shape_typing.dot import dot
from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.types import ONE, Shape1D, Shape2D

COLS = Literal[4]
A: NDArray[Shape2D[COLS, COLS], np.float64] = rand_normal_matrix((4, 4))
B: NDArray[Shape2D[ONE, COLS], np.float64] = rand_normal_matrix((1, 4))
C: NDArray[Shape2D[ONE, COLS], np.float64] = add(A, B)
print(C)
In the example above, our output is a 4x4 matrix, but the output shape we declared in our types is 1x4. Let's see what pyright says:
py -m pyright unnexpected_broadcast.py --lib

No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/unnexpected_broadcast.py
  /mnt/unnexpected_broadcast.py:14:50 - error: Argument of type "NDArray[Shape2D[COLS, COLS], float64]" cannot be assigned to parameter "x1" of type "NDArray[*ShapeVarGen@add, GenericDType@add]" in function "add"
    "NDArray[Shape2D[COLS, COLS], float64]" is incompatible with "NDArray[Shape2D[ONE, COLS], float64]"
    TypeVar "_ShapeType@ndarray" is invariant
    "*tuple[Shape2D[COLS, COLS]]" is incompatible with "*tuple[Shape2D[ONE, COLS]]"
    Tuple entry 1 is incorrect type
    "Shape2D[COLS, COLS]" is incompatible with "Shape2D[ONE, COLS]" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
Completed in 2.73sec
pyright informs us that our shapes are off and that we got broadcasted to a 4x4! Huzzah shape typing!
Hitting the limitations of shape typing 😿
The last function we will type to finish off our Linear example is np.ravel. However this is where we start hitting some major limitations of shape typing as it exists today in python and numpy.
From the numpy docs on np.ravel the only case we need to cover is that any ND array gets collapsed into a 1D array whose size is the total number of elements. Luckily the final 1D size is easy to compute: it's just the product of all the input dimension sizes.
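At runtime that rule is a one-liner. The sketch below uses a hypothetical ravel_size helper on plain shape tuples to show the arithmetic that we would like the type system to do for us.

```python
import math
from typing import Tuple


def ravel_size(shape: Tuple[int, ...]) -> int:
    """Hypothetical helper: number of elements in the flattened array."""
    return math.prod(shape)


print(ravel_size((4, 1)))  # 4, the size Linear's docstring promises
print(ravel_size((4, 4)))  # 16, the size we got from the accidental broadcast
```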
Ideally we would try to write code that looks something like this:
ShapeVarGen = TypeVarTuple("ShapeVarGen")

@overload
def ravel(
    arr: NDArray[Shape[*ShapeVarGen], GenericDType]
) -> NDArray[Shape1D[Product[*ShapeVarGen]], GenericDType]:
    ...
But unfortunately python's typing package currently doesn't have a notion of a Product type that provides a way to do algebraic typing.
However for the sake of completion we can fake it!
If we scope down from a generic ND typing of np.ravel to support up to two dimensions and limit the size of the output dimension to some maximum number, we can overload all the possible factors that multiply to the output dimension size. We would effectively be typing a multiplication table 😆, but it will work and get us to a "partially" typed np.ravel.
Here's how we can do it.
First we create a bunch of Literal types (our factors):
ZERO = Literal[0]
ONE = Literal[1]
TWO = Literal[2]
THREE = Literal[3]
FOUR = Literal[4]
Then we define "multiply" types for factor pairs of numbers:
SHAPE_2D_MUL_TO_ONE = TypeVar(
    "SHAPE_2D_MUL_TO_ONE",
    bound=Shape2D[Literal[ONE], Literal[ONE]],
)
SHAPE_2D_MUL_TO_TWO = TypeVar(
    "SHAPE_2D_MUL_TO_TWO",
    bound=Union[Shape2D[Literal[ONE], Literal[TWO]], Shape2D[Literal[TWO], Literal[ONE]]],
)
SHAPE_2D_MUL_TO_THREE = TypeVar(
    "SHAPE_2D_MUL_TO_THREE",
    bound=Union[Shape2D[Literal[ONE], Literal[THREE]], Shape2D[Literal[THREE], Literal[ONE]]],
)
SHAPE_2D_MUL_TO_FOUR = TypeVar(
    "SHAPE_2D_MUL_TO_FOUR",
    bound=Union[
        Shape2D[Literal[ONE], Literal[FOUR]],
        Shape2D[Literal[TWO], Literal[TWO]],
        Shape2D[Literal[FOUR], Literal[ONE]],
    ],
)
Then lastly we wire these types up into individual ravel overloads (and cover a few generic ones while we're at it):
@overload
def ravel(arr: NDArray[SHAPE_2D_MUL_TO_ONE, GenericDType]) -> NDArray[Shape1D[ONE], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[SHAPE_2D_MUL_TO_TWO, GenericDType]) -> NDArray[Shape1D[TWO], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[SHAPE_2D_MUL_TO_THREE, GenericDType]) -> NDArray[Shape1D[THREE], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[SHAPE_2D_MUL_TO_FOUR, GenericDType]) -> NDArray[Shape1D[FOUR], GenericDType]:
    ...

@overload
def ravel(arr: NDArray[Shape2D[T1, ONE], GenericDType]) -> NDArray[Shape1D[T1], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[Shape2D[ONE, T1], GenericDType]) -> NDArray[Shape1D[T1], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[Shape1D[T1], GenericDType]) -> NDArray[Shape1D[T1], GenericDType]:
    ...
Now we can rinse and repeat for as many numbers as we like!
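Since the table is mechanical, the entries do not have to be worked out by hand. Here is a small sketch (factor_pairs is a hypothetical helper, not part of the accompanying repo) that lists every 2D shape whose sizes multiply to a given number, which are exactly the members each "multiply" type needs in its bound.

```python
from typing import List, Tuple


def factor_pairs(n: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: every 2D shape (rows, cols) with rows * cols == n."""
    return [(rows, n // rows) for rows in range(1, n + 1) if n % rows == 0]


for size in range(1, 5):
    print(size, factor_pairs(size))
```

For 4 this yields (1, 4), (2, 2) and (4, 1), matching the three shapes in the bound of SHAPE_2D_MUL_TO_FOUR. A script along these lines could, in principle, generate the TypeVar definitions and overloads for larger sizes.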
Here is how we'd use this typing to catch a shape type error with ravel:
import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.ravel import ravel
from numpy_shape_typing.types import FOUR, SEVEN, TWO, Shape1D, Shape2D

A: NDArray[Shape2D[TWO, FOUR], np.float64] = rand_normal_matrix((2, 4))
B: NDArray[Shape1D[SEVEN], np.float64] = ravel(A)
print(B)
py -m pyright raveling.py --lib

No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/raveling.py
  /mnt/raveling.py:9:42 - error: Expression of type "NDArray[Shape1D[EIGHT], float64]" cannot be assigned to declared type "NDArray[Shape1D[SEVEN], float64]"
    "NDArray[Shape1D[EIGHT], float64]" is incompatible with "NDArray[Shape1D[SEVEN], float64]"
    TypeVar "_ShapeType@ndarray" is invariant
    "*tuple[Shape1D[EIGHT]]" is incompatible with "*tuple[Shape1D[SEVEN]]"
    Tuple entry 1 is incorrect type
    "Shape1D[EIGHT]" is incompatible with "Shape1D[SEVEN]" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
Completed in 0.925sec
Putting it all together
So far we've gone through typing a small subset of numpy's functions (np.random.standard_normal, np.dot, np.add, and np.ravel in all).
Now we can chain these typed functions together to form a typed Linear implementation like so:
from typing import Literal

import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.add import add
from numpy_shape_typing.dot import dot
from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.ravel import ravel
from numpy_shape_typing.types import ONE, T1, T2, GenericDType, Shape1D, Shape2D


def Linear(
    A: NDArray[Shape2D[T1, T2], GenericDType],
    x: NDArray[Shape2D[T2, ONE], GenericDType],
    b: NDArray[Shape2D[T1, ONE], GenericDType],
) -> NDArray[Shape1D[T1], GenericDType]:
    Ax = dot(A, x)
    Axb = add(Ax, b)
    return ravel(Axb)


IN_DIM = Literal[3]
in_dim: IN_DIM = 3

OUT_DIM = Literal[4]
out_dim: OUT_DIM = 4

# bad type >:(
BAD_OUT_DIM = Literal[5]

A: NDArray[Shape2D[OUT_DIM, IN_DIM], np.float64] = rand_normal_matrix((out_dim, in_dim))
x: NDArray[Shape2D[IN_DIM, ONE], np.float64] = rand_normal_matrix((in_dim, 1))
b: NDArray[Shape2D[OUT_DIM, ONE], np.float64] = rand_normal_matrix((out_dim, 1))

# this is a bad type!
y: NDArray[Shape1D[BAD_OUT_DIM], np.float64] = Linear(A, x, b)
I've included an intentional type error which should be caught by pyright like so:
py -m pyright linear_type_bad.py --lib

No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/linear_type_bad.py
  /mnt/linear_type_bad.py:37:55 - error: Argument of type "NDArray[Shape2D[OUT_DIM, IN_DIM], float64]" cannot be assigned to parameter "A" of type "NDArray[Shape2D[T1@Linear, T2@Linear], GenericDType@Linear]" in function "Linear"
    "NDArray[Shape2D[OUT_DIM, IN_DIM], float64]" is incompatible with "NDArray[Shape2D[BAD_OUT_DIM, IN_DIM], float64]"
    TypeVar "_ShapeType@ndarray" is invariant
    "*tuple[Shape2D[OUT_DIM, IN_DIM]]" is incompatible with "*tuple[Shape2D[BAD_OUT_DIM, IN_DIM]]"
    Tuple entry 1 is incorrect type
    "Shape2D[OUT_DIM, IN_DIM]" is incompatible with "Shape2D[BAD_OUT_DIM, IN_DIM]" (reportGeneralTypeIssues)
  /mnt/linear_type_bad.py:37:61 - error: Argument of type "NDArray[Shape2D[OUT_DIM, ONE], float64]" cannot be assigned to parameter "b" of type "NDArray[Shape2D[T1@Linear, ONE], GenericDType@Linear]" in function "Linear"
    "NDArray[Shape2D[OUT_DIM, ONE], float64]" is incompatible with "NDArray[Shape2D[BAD_OUT_DIM, ONE], float64]"
    TypeVar "_ShapeType@ndarray" is invariant
    "*tuple[Shape2D[OUT_DIM, ONE]]" is incompatible with "*tuple[Shape2D[BAD_OUT_DIM, ONE]]"
    Tuple entry 1 is incorrect type
    "Shape2D[OUT_DIM, ONE]" is incompatible with "Shape2D[BAD_OUT_DIM, ONE]" (reportGeneralTypeIssues)
2 errors, 0 warnings, 0 informations
Completed in 8.132sec
And huzzah again! pyright has caught the shape type error!
And now we can fix this shape error by changing BAD_OUT_DIM to the correct output dimension size.
from typing import Literal

import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.add import add
from numpy_shape_typing.dot import dot
from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.ravel import ravel
from numpy_shape_typing.types import ONE, T1, T2, GenericDType, Shape1D, Shape2D


def Linear(
    A: NDArray[Shape2D[T1, T2], GenericDType],
    x: NDArray[Shape2D[T2, ONE], GenericDType],
    b: NDArray[Shape2D[T1, ONE], GenericDType],
) -> NDArray[Shape1D[T1], GenericDType]:
    """
    Args:
        A: ndarray of shape (M x N)
        x: ndarray of shape (N x 1)
        b: ndarray of shape (M x 1)

    Returns:
        Linear output ndarray of shape (M)
    """
    Ax = dot(A, x)
    Axb = add(Ax, b)
    return ravel(Axb)


IN_DIM = Literal[3]
in_dim: IN_DIM = 3

OUT_DIM = Literal[4]
out_dim: OUT_DIM = 4

A: NDArray[Shape2D[OUT_DIM, IN_DIM], np.float64] = rand_normal_matrix((out_dim, in_dim))
x: NDArray[Shape2D[IN_DIM, ONE], np.float64] = rand_normal_matrix((in_dim, 1))
b: NDArray[Shape2D[OUT_DIM, ONE], np.float64] = rand_normal_matrix((out_dim, 1))
y: NDArray[Shape1D[OUT_DIM], np.float64] = Linear(A, x, b)
And if we check with pyright:

py -m pyright linear_type_good.py --lib

No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
0 errors, 0 warnings, 0 informations
Completed in 8.131sec
pyright tells us that our types are consistent!
What's next?
You tell me! Many open source scientific computing libraries have GitHub issues
about shape typing such as:
So it's well recognized as a desirable feature. Some of the major technical
hurdles we still need to overcome are:
Once these hurdles are overcome I don't see any blockers stopping projects like numpy from being fully shape typed.
This post and the accompanying repo are just a sample of what shape typing might become. With future PEPs and work on the python type hinting system, we'll hopefully make our code incrementally safer.
Thanks for reading! (っ◔◡◔)っ ♥