File tree Expand file tree Collapse file tree 1 file changed +16
-5
lines changed
Filter options
numpy/_core/src/multiarray Expand file tree Collapse file tree 1 file changed +16
-5
lines changed
Original file line number Diff line number Diff line change @@ -897,18 +897,29 @@ can_cast_pyscalar_scalar_to(
897
897
}
898
898
899
899
/*
900
- * For all other cases we use the default dtype.
900
+ * For all other cases we need to make a bit of a dance to find the cast
901
+ * safety. We do so by finding the descriptor for the "scalar" (without
902
+ * a value; for parametric user dtypes a value may be needed eventually).
901
903
*/
902
- PyArray_Descr * from ;
904
+ PyArray_DTypeMeta * from_DType ;
905
+ PyArray_Descr * default_dtype ;
903
906
if (flags & NPY_ARRAY_WAS_PYTHON_INT ) {
904
- from = PyArray_DescrFromType (NPY_LONG );
907
+ default_dtype = PyArray_DescrNewFromType (NPY_INTP );
908
+ from_DType = & PyArray_PyLongDType ;
905
909
}
906
910
else if (flags & NPY_ARRAY_WAS_PYTHON_FLOAT ) {
907
- from = PyArray_DescrFromType (NPY_DOUBLE );
911
+ default_dtype = PyArray_DescrNewFromType (NPY_FLOAT64 );
912
+ from_DType = & PyArray_PyFloatDType ;
908
913
}
909
914
else {
910
- from = PyArray_DescrFromType (NPY_CDOUBLE );
915
+ default_dtype = PyArray_DescrNewFromType (NPY_COMPLEX128 );
916
+ from_DType = & PyArray_PyComplexDType ;
911
917
}
918
+
919
+ PyArray_Descr * from = npy_find_descr_for_scalar (
920
+ NULL , default_dtype , from_DType , NPY_DTYPE (to ));
921
+ Py_DECREF (default_dtype );
922
+
912
923
int res = PyArray_CanCastTypeTo (from , to , casting );
913
924
Py_DECREF (from );
914
925
return res ;
You can’t perform that action at this time.
0 commit comments