Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Dec 10, 2024
1 parent b22bbfb commit 8dc378e
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 154 deletions.
158 changes: 17 additions & 141 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ crate-type = ["cdylib", "rlib"]
path = "src/lib.rs" # The source file of the target.

[dependencies]
pyo3 = { version = "0.21.2", features = ["extension-module"] }
numpy = "0.21"
pyo3 = { version = "0.23.3", features = ["extension-module"] }
numpy = "0.23"
25 changes: 14 additions & 11 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ impl Network {
}

#[getter]
pub fn get_node_trajectories<'py>(&self, py: Python<'py>) -> PyResult<&'py PyList> {
pub fn get_node_trajectories<'py>(&self, py: Python<'py>) -> PyResult<Py<PyList>> {
let py_list = PyList::empty(py);


Expand All @@ -210,27 +210,27 @@ impl Network {
if let Some(vector_node) = self.node_trajectories.vectors.get(node_idx) {
for (vector_key, vector_value) in vector_node {
// Create a new Python dictionary
py_dict.set_item(vector_key, PyArray::from_vec2_bound(py, &vector_value).unwrap()).expect("Failed to set item in PyDict");
py_dict.set_item(vector_key, PyArray::from_vec2(py, &vector_value).unwrap()).expect("Failed to set item in PyDict");
}
}
py_list.append(py_dict)?;
}

// Create a PyList from Vec<usize>
Ok(py_list)
Ok(py_list.into())
}

#[getter]
pub fn get_inputs<'py>(&self, py: Python<'py>) -> PyResult<&'py PyList> {
pub fn get_inputs<'py>(&self, py: Python<'py>) -> PyResult<Py<PyList>> {
let py_list = PyList::new(py, &self.inputs); // Create a PyList from Vec<usize>
Ok(py_list)
Ok(py_list.into())
}

#[getter]
pub fn get_edges<'py>(&self, py: Python<'py>) -> PyResult<&'py PyList> {
pub fn get_edges<'py>(&self, py: Python<'py>) -> PyResult<Py<PyList>> {
// Create a new Python list
let py_list = PyList::empty(py);

// Convert each struct in the Vec to a Python object and add to PyList
for i in 0..self.edges.len() {
// Create a new Python dictionary for each MyStruct
Expand All @@ -239,15 +239,18 @@ impl Network {
py_dict.set_item("value_children", &self.edges[&i].value_children)?;
py_dict.set_item("volatility_parents", &self.edges[&i].volatility_parents)?;
py_dict.set_item("volatility_children", &self.edges[&i].volatility_children)?;

// Add the dictionary to the list
py_list.append(py_dict)?;
}
Ok(py_list)

// Return the PyList object directly
Ok(py_list.into())
}


#[getter]
pub fn get_update_sequence<'py>(&self, py: Python<'py>) -> PyResult<&'py PyList> {
pub fn get_update_sequence<'py>(&self, py: Python<'py>) -> PyResult<Py<PyList>> {

let func_map = get_func_map();
let py_list = PyList::empty(py);
Expand All @@ -272,7 +275,7 @@ impl Network {
// Append the Python tuple to the Python list
py_list.append(py_tuple)?;
}
Ok(py_list)
Ok(py_list.into())
}
}

Expand Down

0 comments on commit 8dc378e

Please sign in to comment.