1
0
mirror of https://github.com/msberends/AMR.git synced 2026-03-30 17:35:58 +02:00

fix: convert Python lists to R vectors in wrapper to prevent R list coercion errors

Fixes #267. Python lists passed to R functions via rpy2 are received as
R lists, not R character/numeric vectors. This causes is.mic(), is.sir(),
is.disk() etc. to return length > 1 logicals, which break R's && operator.

Added convert_to_r() helper that maps Python list/tuple to the appropriate
typed R vector (StrVector, IntVector, FloatVector) based on element types.
The r_to_python decorator now applies this to all args and kwargs before
calling the R function.
This commit is contained in:
Matthijs Berends
2026-03-29 14:34:58 +02:00
parent 9c95aa455c
commit 6a7e8ce036

View File

@@ -141,6 +141,32 @@ import numpy as np
# Import the AMR R package # Import the AMR R package
amr_r = importr('AMR') amr_r = importr('AMR')
def convert_to_r(value):
"""Convert Python lists/tuples to typed R vectors.
rpy2's default_converter passes Python lists to R as R lists, not as
character/numeric vectors. This causes element-wise type-check functions
such as is.mic(), is.sir(), and is.disk() to return a logical vector
rather than a single logical, breaking R's scalar && operator.
This helper converts Python lists and tuples to the appropriate R vector
type based on the element types, so R always receives a proper vector."""
if isinstance(value, (list, tuple)):
if len(value) == 0:
return StrVector([])
# bool must be checked before int because bool is a subclass of int
if all(isinstance(v, bool) for v in value):
return robjects.vectors.BoolVector(value)
if all(isinstance(v, int) for v in value):
return IntVector(value)
if all(isinstance(v, float) for v in value):
return FloatVector(value)
if all(isinstance(v, str) for v in value):
return StrVector(value)
# Mixed types: coerce all to string
return StrVector([str(v) for v in value])
return value
def convert_to_python(r_output): def convert_to_python(r_output):
# Check if it's a StrVector (R character vector) # Check if it's a StrVector (R character vector)
if isinstance(r_output, StrVector): if isinstance(r_output, StrVector):
@@ -166,10 +192,13 @@ def convert_to_python(r_output):
return r_output return r_output
def r_to_python(r_func): def r_to_python(r_func):
"""Decorator that runs an rpy2 function under a localconverter """Decorator that converts Python list/tuple inputs to typed R vectors,
and then applies convert_to_python to its output.""" runs the rpy2 function under a localconverter, and converts the output
to a Python type."""
@functools.wraps(r_func) @functools.wraps(r_func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
args = tuple(convert_to_r(a) for a in args)
kwargs = {k: convert_to_r(v) for k, v in kwargs.items()}
with localconverter(default_converter + numpy2ri.converter + pandas2ri.converter): with localconverter(default_converter + numpy2ri.converter + pandas2ri.converter):
return convert_to_python(r_func(*args, **kwargs)) return convert_to_python(r_func(*args, **kwargs))
return wrapper return wrapper
@@ -312,4 +341,3 @@ cd ../PythonPackage/AMR
pip3 install build pip3 install build
python3 -m build python3 -m build
# python3 setup.py sdist bdist_wheel # python3 setup.py sdist bdist_wheel