diff --git a/docs/src/juliacall-reference.md b/docs/src/juliacall-reference.md index ef5e5e5c..319da77f 100644 --- a/docs/src/juliacall-reference.md +++ b/docs/src/juliacall-reference.md @@ -202,9 +202,9 @@ jl.Vector[jl.Int]() ``` Some Julia types can be converted to corresponding numpy dtypes like `numpy.dtype(jl.Int)`. -Supports primitive types: `Bool`, `IntXX`, `UIntXX`, `FloatXX`, `ComplexFXX`, -`NumpyDates.InlineDateTime64{unit}` and `NumpyDates.InlineTimeDelta64{unit}`. Also -supports tuples, named tuples and structs of these. +Supports `Bool`, `IntXX`, `UIntXX`, `FloatXX`, `ComplexFXX`, +`NumpyDates.InlineDateTime64{unit}` and `NumpyDates.InlineTimeDelta64{unit}`, plus +`Tuple`s and `NamedTuple`s of these. ````` `````@customdoc diff --git a/src/JlWrap/array.jl b/src/JlWrap/array.jl index 299aedcc..6ed25176 100644 --- a/src/JlWrap/array.jl +++ b/src/JlWrap/array.jl @@ -187,7 +187,7 @@ pybufferformat(::Type{T}) where {T} = T == Complex{Cdouble} ? "Zd" : T == Bool ? "?" : T == Ptr{Cvoid} ? "P" : - if isstructtype(T) && isconcretetype(T) && allocatedinline(T) + if (T <: Union{Tuple,NamedTuple}) && isstructtype(T) && isconcretetype(T) && allocatedinline(T) n = fieldcount(T) flds = [] for i = 1:n @@ -234,7 +234,7 @@ pyjlarray_isarrayabletype(::Type{NamedTuple{names,types}}) where {names,types} = const PYTYPESTRDESCR = IdDict{Type,Tuple{String,Py}}() -pytypestrdescr(::Type{T}) where {T} = +function pytypestrdescr(::Type{T}) where {T} get!(PYTYPESTRDESCR, T) do c = Utils.islittleendian() ? '<' : '>' if T == Bool @@ -275,7 +275,7 @@ pytypestrdescr(::Type{T}) where {T} = u == NumpyDates.UNBOUND_UNITS ? "" : m == 1 ? 
"[$(Symbol(u))]" : "[$(m)$(Symbol(u))]" ("$(c)$(tc)8$(us)", PyNULL) - elseif isstructtype(T) && isconcretetype(T) && Base.allocatedinline(T) + elseif (T <: Union{Tuple,NamedTuple}) && isstructtype(T) && isconcretetype(T) && Base.allocatedinline(T) n = fieldcount(T) flds = [] for i = 1:n @@ -298,6 +298,7 @@ pytypestrdescr(::Type{T}) where {T} = ("", PyNULL) end end +end pyjlarray_array__array(x::AbstractArray) = x isa Array ? Py(nothing) : pyjl(Array(x)) pyjlarray_array__pyobjectarray(x::AbstractArray) = pyjl(PyObjectArray(x)) diff --git a/src/JlWrap/type.jl b/src/JlWrap/type.jl index 00bbf058..669111ac 100644 --- a/src/JlWrap/type.jl +++ b/src/JlWrap/type.jl @@ -11,22 +11,49 @@ function pyjltype_getitem(self::Type, k_) end end +const PYNUMPYDTYPE = IdDict{Type,Py}() + function pyjltype_numpy_dtype(self::Type) - typestr, descr = pytypestrdescr(self) - if isempty(typestr) - errset(pybuiltins.AttributeError, "__numpy_dtype__") - return PyNULL + ans = get!(PYNUMPYDTYPE, self) do + typestr, descr = pytypestrdescr(self) + # unsupported type + if typestr == "" + return PyNULL + end + np = pyimport("numpy") + # simple scalar type + if pyisnull(descr) + return np.dtype(typestr) + end + # We could just use np.dtype(descr), but when there is padding, np.dtype(descr) + # changes the names of the padding fields from "" to "f{N}". Using this other + # dtype constructor avoids this issue and preserves the invariant: + # np.dtype(eltype(array)) == np.array(array).dtype + names = [] + formats = [] + offsets = [] + for i = 1:fieldcount(self) + nm = fieldname(self, i) + push!(names, nm isa Integer ? "f$(nm-1)" : String(nm)) + ts, ds = pytypestrdescr(fieldtype(self, i)) + push!(formats, pyisnull(ds) ? 
ts : ds) + push!(offsets, fieldoffset(self, i)) + end + return np.dtype( + pydict( + names = pylist(names), + formats = pylist(formats), + offsets = pylist(offsets), + itemsize = sizeof(self), + ), + ) end - np = pyimport("numpy") - if pyisnull(descr) - return np.dtype(typestr) - else - return np.dtype(descr) + if pyisnull(ans) + errset(pybuiltins.AttributeError, "__numpy_dtype__") end + return ans end -pyjl_handle_error_type(::typeof(pyjltype_numpy_dtype), x, exc) = pybuiltins.AttributeError - function init_type() jl = pyjuliacallmodule pybuiltins.exec( diff --git a/test/JlWrap.jl b/test/JlWrap.jl index 818f83de..3095693d 100644 --- a/test/JlWrap.jl +++ b/test/JlWrap.jl @@ -510,15 +510,18 @@ end (Tuple{Int32, Int32}, pylist([("f0", "int32"), ("f1", "int32")])), (@NamedTuple{}, pylist()), (@NamedTuple{x::Int32, y::Int32}, pylist([("x", "int32"), ("y", "int32")])), - (Pair{Int32, Int32}, pylist([("first", "int32"), ("second", "int32")])), ] @test pyeq(Bool, pygetattr(pyjl(t), "__numpy_dtype__"), np.dtype(d)) - @test pyeq(Bool, np.dtype(pyjl(t)), np.dtype(d)) + @test pyeq(Bool, np.dtype(t), np.dtype(d)) + # test the invariant np.dtype(eltype(array)) == np.array(array).dtype + @test isequal(np.dtype(t), np.array(t[]).dtype) end # unsupported cases @testset "$t -> AttributeError" for t in [ - # non-primitives or mutables + # structs / mutables + Pair, + Pair{Int,Int}, String, Vector{Int}, # pointers @@ -526,6 +529,9 @@ end Ptr{Int}, # PyPtr specifically should NOT be interpreted as np.dtype("O") PythonCall.C.PyPtr, + # tuples containing illegal things + Tuple{String}, + Tuple{Pair{Int,Int}}, ] err = try pygetattr(pyjl(t), "__numpy_dtype__")