From 7357d680415a8791509015ebbae3207b99e22610 Mon Sep 17 00:00:00 2001 From: Christopher Rowley Date: Fri, 23 Jan 2026 21:41:08 +0000 Subject: [PATCH] Adds tests (and fixes) for the dtype invariant. That is, np.array(array).dtype == np.dtype(eltype(array)). For this to hold, we needed to restrict to only creating dtypes for primitives, tuples and named tuples. We removed support for arbitrary structs (which are not supported by our implementation of the array interface and buffer protocol). We also worked around a feature/bug/quirk of numpy in that if you do numpy.dtype(descr) where descr is a list of (name,type) field descriptors of a struct, then the dtype you get is not the same as the dtype of an array constructed from something whose array interface has that same descr. In particular, if any item in descr is struct padding like ("", "|V4"), then on conversion to a dtype the name is replaced with e.g. "f2". Going the array route, the padding gets ignored and does not feature in the resulting dtype. The fix here is to compute a different representation of the same information for the dtype - namely the dict of names, types and offsets way. --- docs/src/juliacall-reference.md | 6 ++-- src/JlWrap/array.jl | 7 +++-- src/JlWrap/type.jl | 49 +++++++++++++++++++++++++-------- test/JlWrap.jl | 12 ++++++-- 4 files changed, 54 insertions(+), 20 deletions(-) diff --git a/docs/src/juliacall-reference.md b/docs/src/juliacall-reference.md index ef5e5e5c..319da77f 100644 --- a/docs/src/juliacall-reference.md +++ b/docs/src/juliacall-reference.md @@ -202,9 +202,9 @@ jl.Vector[jl.Int]() ``` Some Julia types can be converted to corresponding numpy dtypes like `numpy.dtype(jl.Int)`. -Supports primitive types: `Bool`, `IntXX`, `UIntXX`, `FloatXX`, `ComplexFXX`, -`NumpyDates.InlineDateTime64{unit}` and `NumpyDates.InlineTimeDelta64{unit}`. Also -supports tuples, named tuples and structs of these. 
+Supports `Bool`, `IntXX`, `UIntXX`, `FloatXX`, `ComplexFXX`, +`NumpyDates.InlineDateTime64{unit}` and `NumpyDates.InlineTimeDelta64{unit}`, plus +`Tuple`s and `NamedTuple`s of these. ````` `````@customdoc diff --git a/src/JlWrap/array.jl b/src/JlWrap/array.jl index 299aedcc..6ed25176 100644 --- a/src/JlWrap/array.jl +++ b/src/JlWrap/array.jl @@ -187,7 +187,7 @@ pybufferformat(::Type{T}) where {T} = T == Complex{Cdouble} ? "Zd" : T == Bool ? "?" : T == Ptr{Cvoid} ? "P" : - if isstructtype(T) && isconcretetype(T) && allocatedinline(T) + if (T <: Union{Tuple,NamedTuple}) && isstructtype(T) && isconcretetype(T) && allocatedinline(T) n = fieldcount(T) flds = [] for i = 1:n @@ -234,7 +234,7 @@ pyjlarray_isarrayabletype(::Type{NamedTuple{names,types}}) where {names,types} = const PYTYPESTRDESCR = IdDict{Type,Tuple{String,Py}}() -pytypestrdescr(::Type{T}) where {T} = +function pytypestrdescr(::Type{T}) where {T} get!(PYTYPESTRDESCR, T) do c = Utils.islittleendian() ? '<' : '>' if T == Bool @@ -275,7 +275,7 @@ pytypestrdescr(::Type{T}) where {T} = u == NumpyDates.UNBOUND_UNITS ? "" : m == 1 ? "[$(Symbol(u))]" : "[$(m)$(Symbol(u))]" ("$(c)$(tc)8$(us)", PyNULL) - elseif isstructtype(T) && isconcretetype(T) && Base.allocatedinline(T) + elseif (T <: Union{Tuple,NamedTuple}) && isstructtype(T) && isconcretetype(T) && Base.allocatedinline(T) n = fieldcount(T) flds = [] for i = 1:n @@ -298,6 +298,7 @@ pytypestrdescr(::Type{T}) where {T} = ("", PyNULL) end end +end pyjlarray_array__array(x::AbstractArray) = x isa Array ? 
Py(nothing) : pyjl(Array(x)) pyjlarray_array__pyobjectarray(x::AbstractArray) = pyjl(PyObjectArray(x)) diff --git a/src/JlWrap/type.jl b/src/JlWrap/type.jl index 00bbf058..669111ac 100644 --- a/src/JlWrap/type.jl +++ b/src/JlWrap/type.jl @@ -11,22 +11,49 @@ function pyjltype_getitem(self::Type, k_) end end +const PYNUMPYDTYPE = IdDict{Type,Py}() + function pyjltype_numpy_dtype(self::Type) - typestr, descr = pytypestrdescr(self) - if isempty(typestr) - errset(pybuiltins.AttributeError, "__numpy_dtype__") - return PyNULL + ans = get!(PYNUMPYDTYPE, self) do + typestr, descr = pytypestrdescr(self) + # unsupported type + if typestr == "" + return PyNULL + end + np = pyimport("numpy") + # simple scalar type + if pyisnull(descr) + return np.dtype(typestr) + end + # We could just use np.dtype(descr), but when there is padding, np.dtype(descr) + # changes the names of the padding fields from "" to "f{N}". Using this other + # dtype constructor avoids this issue and preserves the invariant: + # np.dtype(eltype(array)) == np.array(array).dtype + names = [] + formats = [] + offsets = [] + for i = 1:fieldcount(self) + nm = fieldname(self, i) + push!(names, nm isa Integer ? "f$(nm-1)" : String(nm)) + ts, ds = pytypestrdescr(fieldtype(self, i)) + push!(formats, pyisnull(ds) ? 
ts : ds) + push!(offsets, fieldoffset(self, i)) + end + return np.dtype( + pydict( + names = pylist(names), + formats = pylist(formats), + offsets = pylist(offsets), + itemsize = sizeof(self), + ), + ) end - np = pyimport("numpy") - if pyisnull(descr) - return np.dtype(typestr) - else - return np.dtype(descr) + if pyisnull(ans) + errset(pybuiltins.AttributeError, "__numpy_dtype__") end + return ans end -pyjl_handle_error_type(::typeof(pyjltype_numpy_dtype), x, exc) = pybuiltins.AttributeError - function init_type() jl = pyjuliacallmodule pybuiltins.exec( diff --git a/test/JlWrap.jl b/test/JlWrap.jl index 818f83de..3095693d 100644 --- a/test/JlWrap.jl +++ b/test/JlWrap.jl @@ -510,15 +510,18 @@ end (Tuple{Int32, Int32}, pylist([("f0", "int32"), ("f1", "int32")])), (@NamedTuple{}, pylist()), (@NamedTuple{x::Int32, y::Int32}, pylist([("x", "int32"), ("y", "int32")])), - (Pair{Int32, Int32}, pylist([("first", "int32"), ("second", "int32")])), ] @test pyeq(Bool, pygetattr(pyjl(t), "__numpy_dtype__"), np.dtype(d)) - @test pyeq(Bool, np.dtype(pyjl(t)), np.dtype(d)) + @test pyeq(Bool, np.dtype(t), np.dtype(d)) + # test the invariant np.dtype(eltype(array)) == np.array(array).dtype + @test isequal(np.dtype(t), np.array(t[]).dtype) end # unsupported cases @testset "$t -> AttributeError" for t in [ - # non-primitives or mutables + # structs / mutables + Pair, + Pair{Int,Int}, String, Vector{Int}, # pointers @@ -526,6 +529,9 @@ end Ptr{Int}, # PyPtr specifically should NOT be interpreted as np.dtype("O") PythonCall.C.PyPtr, + # tuples containing illegal things + Tuple{String}, + Tuple{Pair{Int,Int}}, ] err = try pygetattr(pyjl(t), "__numpy_dtype__")