diff --git a/data-raw/_generate_python_wrapper.sh b/data-raw/_generate_python_wrapper.sh index a8b2ee01a..f56b139fd 100644 --- a/data-raw/_generate_python_wrapper.sh +++ b/data-raw/_generate_python_wrapper.sh @@ -141,6 +141,32 @@ import numpy as np # Import the AMR R package amr_r = importr('AMR') +def convert_to_r(value): + """Convert Python lists/tuples to typed R vectors. + + rpy2's default_converter passes Python lists to R as R lists, not as + character/numeric vectors. This causes element-wise type-check functions + such as is.mic(), is.sir(), and is.disk() to return a logical vector + rather than a single logical, breaking R's scalar && operator. + + This helper converts Python lists and tuples to the appropriate R vector + type based on the element types, so R always receives a proper vector.""" + if isinstance(value, (list, tuple)): + if len(value) == 0: + return StrVector([]) + # bool must be checked before int because bool is a subclass of int + if all(isinstance(v, bool) for v in value): + return robjects.vectors.BoolVector(value) + if all(isinstance(v, int) for v in value): + return IntVector(value) + if all(isinstance(v, float) for v in value): + return FloatVector(value) + if all(isinstance(v, str) for v in value): + return StrVector(value) + # Mixed types: coerce all to string + return StrVector([str(v) for v in value]) + return value + def convert_to_python(r_output): # Check if it's a StrVector (R character vector) if isinstance(r_output, StrVector): @@ -166,10 +192,13 @@ def convert_to_python(r_output): return r_output def r_to_python(r_func): - """Decorator that runs an rpy2 function under a localconverter - and then applies convert_to_python to its output.""" + """Decorator that converts Python list/tuple inputs to typed R vectors, + runs the rpy2 function under a localconverter, and converts the output + to a Python type.""" @functools.wraps(r_func) def wrapper(*args, **kwargs): + args = tuple(convert_to_r(a) for a in args) + kwargs = {k: convert_to_r(v) for k, v in kwargs.items()} with localconverter(default_converter + numpy2ri.converter + pandas2ri.converter): return convert_to_python(r_func(*args, **kwargs)) return wrapper @@ -312,4 +341,3 @@ cd ../PythonPackage/AMR pip3 install build python3 -m build # python3 setup.py sdist bdist_wheel -