neutralts/bif/
exec_python.rs

1// pyo3 = { version = "0.26.0", features = [] }
2
3use crate::{bif::BifError, Value};
4use pyo3::prelude::*;
5use pyo3::types::{PyList, PyModule};
6use std::path::Path;
7use std::env;
8use std::process::Command;
9
10pub struct PythonExecutor;
11
12impl PythonExecutor {
13    pub(crate) fn exec_py(
14        file: &str,
15        params_value: &Value,
16        callback_name: &str,
17        schema: Option<&Value>,
18        venv_path: Option<&str>,
19    ) -> Result<Value, BifError> {
20        if let Some(venv) = venv_path {
21            Self::setup_venv(venv)?;
22        }
23
24        Python::initialize();
25
26        Python::attach(|py| -> PyResult<Value> {
27            let params = Self::prepare_python_params(py, params_value)?;
28            Self::setup_python_path(py, file)?;
29            Self::execute_python_callback(py, file, callback_name, params, schema)
30        })
31        .map_err(|e| BifError {
32            msg: format!(
33                "Error executing callback function '{}': {}",
34                callback_name, e
35            ),
36            name: "python_callback".to_string(),
37            file: file.to_string(),
38            src: file.to_string(),
39        })
40    }
41
42    fn setup_venv(venv_path: &str) -> Result<(), BifError> {
43        let path = Path::new(venv_path);
44        if !path.exists() {
45            return Err(BifError {
46                msg: format!("Venv path '{}' does not exist", venv_path),
47                name: "venv_error".to_string(),
48                file: "".to_string(),
49                src: "".to_string(),
50            });
51        }
52
53        let python_executable = if cfg!(unix) {
54            format!("{}/bin/python", venv_path)
55        } else {
56            format!("{}\\Scripts\\python.exe", venv_path)
57        };
58
59        if !Path::new(&python_executable).exists() {
60            return Err(BifError {
61                msg: format!("Python executable not found: {}", python_executable),
62                name: "venv_error".to_string(),
63                file: "".to_string(),
64                src: "".to_string(),
65            });
66        }
67
68        env::set_var("PYTHON_EXECUTABLE", &python_executable);
69        env::set_var("VIRTUAL_ENV", venv_path);
70
71        let output = Command::new(&python_executable)
72            .arg("-c")
73            .arg("import sys; print(sys.prefix); print(':'.join(sys.path))")
74            .output()
75            .map_err(|e| BifError {
76                msg: format!("Failed to get Python path info: {}", e),
77                name: "venv_error".to_string(),
78                file: "".to_string(),
79                src: "".to_string(),
80            })?;
81
82        if output.status.success() {
83            let output_str = String::from_utf8_lossy(&output.stdout);
84            let lines: Vec<&str> = output_str.trim().split('\n').collect();
85            if lines.len() >= 2 {
86                env::set_var("PYTHONHOME", lines[0]);
87                env::set_var("PYTHONPATH", lines[1]);
88            }
89        }
90
91        Ok(())
92    }
93
94    fn prepare_python_params<'py>(py: Python<'py>, params_value: &Value) -> PyResult<Py<PyAny>> {
95        let params_json = serde_json::to_string(params_value).map_err(|e| {
96            pyo3::exceptions::PyValueError::new_err(format!("Failed to serialize params: {}", e))
97        })?;
98        let json_mod = PyModule::import(py, "json")?;
99        let loads = json_mod.getattr("loads")?;
100        let py_obj = loads.call1((params_json,))?;
101        let py_object: Py<PyAny> = py_obj.extract()?;
102        Ok(py_object)
103    }
104
105    fn setup_python_path(py: Python, file: &str) -> PyResult<()> {
106        let dir_path = Path::new(file).parent().unwrap_or_else(|| Path::new("."));
107        let sys = PyModule::import(py, "sys")?;
108        let path_attr = sys.getattr("path")?;
109        let path = path_attr.downcast::<PyList>()?;
110        if let Some(dir_str) = dir_path.to_str() {
111            path.append(dir_str)?;
112        } else {
113            return Err(pyo3::exceptions::PyValueError::new_err(
114                "Invalid directory path encoding",
115            ));
116        }
117        Ok(())
118    }
119
120    fn execute_python_callback<'py>(
121        py: Python<'py>,
122        file: &str,
123        callback_name: &str,
124        params: Py<PyAny>,
125        schema: Option<&Value>,
126    ) -> PyResult<Value> {
127        let module_name = Self::extract_module_name(file)?;
128        let module = PyModule::import(py, &module_name)?;
129
130        if let Some(schema_value) = schema {
131            let schema_py = Self::prepare_python_params(py, schema_value)?;
132            module.setattr("__NEUTRAL_SCHEMA__", schema_py)?;
133        }
134
135        let callback_func = module.getattr(callback_name).map_err(|_| {
136            pyo3::exceptions::PyAttributeError::new_err(format!(
137                "Module does not have function '{}'",
138                callback_name
139            ))
140        })?;
141        let result_any = callback_func.call1((params,))?;
142        let result_obj: Py<PyAny> = result_any.extract()?;
143        Self::convert_python_result_to_json(py, result_obj)
144    }
145
146    fn extract_module_name(file: &str) -> PyResult<String> {
147        Path::new(file)
148            .file_stem()
149            .and_then(|s| s.to_str())
150            .map(|s| s.to_string())
151            .ok_or_else(|| {
152                pyo3::exceptions::PyValueError::new_err(
153                    "Could not extract module name from file path",
154                )
155            })
156    }
157
158    fn convert_python_result_to_json<'py>(py: Python<'py>, result: Py<PyAny>) -> PyResult<Value> {
159        let json_module = PyModule::import(py, "json")?;
160        let json_dumps = json_module.getattr("dumps")?;
161        let json_string: String = json_dumps.call1((result,))?.extract()?;
162        serde_json::from_str(&json_string).map_err(|e| {
163            pyo3::exceptions::PyValueError::new_err(format!("Error parsing JSON: {}", e))
164        })
165    }
166}