diff --git a/Cargo.lock b/Cargo.lock index 1335239..e359619 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2176,6 +2176,17 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "pyo3_rayon" +version = "0.1.0" +dependencies = [ + "csv", + "pyo3", + "rayon", + "regex", + "serde", +] + [[package]] name = "quote" version = "1.0.35" diff --git a/Cargo.toml b/Cargo.toml index 13bb738..8ee0857 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,5 +13,6 @@ members = [ "src/rest_api_postgres/rust", "src/parallelism/rust", "src/pyo3_mock_data", + "src/pyo3_rayon", ] resolver = "2" diff --git a/src/pyo3_rayon/Cargo.toml b/src/pyo3_rayon/Cargo.toml new file mode 100644 index 0000000..7397338 --- /dev/null +++ b/src/pyo3_rayon/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "pyo3_rayon" +version = "0.1.0" +edition = "2021" + +[lib] +name = "pyo3_rayon" +crate-type = ["cdylib"] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +csv = "1.3.0" +pyo3 = { version = "0.20.2", features = ["abi3-py38", "extension-module"] } +rayon = "1.8.1" +regex = "1.10.3" +serde = { version = "1.0.196", features = ["derive"] } diff --git a/src/pyo3_rayon/requirements.txt b/src/pyo3_rayon/requirements.txt new file mode 100644 index 0000000..2dbbe3d --- /dev/null +++ b/src/pyo3_rayon/requirements.txt @@ -0,0 +1 @@ +maturin==1.4.0 diff --git a/src/pyo3_rayon/src/lib.rs b/src/pyo3_rayon/src/lib.rs new file mode 100644 index 0000000..149d9be --- /dev/null +++ b/src/pyo3_rayon/src/lib.rs @@ -0,0 +1,67 @@ +use pyo3::prelude::*; +use rayon::prelude::*; +use regex::{Captures, Regex}; +use std::error::Error; + +#[derive(FromPyObject)] +struct Record { + id: u32, + content: String, +} + +struct RecordProcessed { + id: u32, + n_m: u64, + n_f: u64, +} + +fn clean_text(text: &str) -> String { + let pattern1 = Regex::new(r"([’'])(s|d|ll)").unwrap(); + // Replace pattern with text + let matched = pattern1.replace_all(text, |capture: &Captures| match &capture[2] { + "s" => " is", + "d" => " had", + "ll" => " will", + _ => "", + }); + // Remove non-alphabetic characters + let pattern2 = Regex::new(r"[^a-zA-Z\s]").unwrap(); + let clean_text = pattern2.replace_all(&matched, ""); + let result: String = clean_text.to_lowercase(); + result +} + +fn count_gendered_pronouns(text: &str) -> Result<(usize, usize), Box> { + let clean_text = clean_text(text); + let tokens = clean_text.split_whitespace().collect::>(); + let n_m = tokens + .par_iter() + .filter(|&x| *x == "he" || *x == "him" || *x == "his") + .count(); + let n_f = tokens + .par_iter() + .filter(|&x| *x == "she" || *x == "her" || *x == "hers") + .count(); + Ok((n_m, n_f)) +} + +#[pyfunction(signature = (records))] +fn get_pronoun_counts(records: Vec) -> PyResult> { + let mut result = vec![]; + for record in records { + let (n_m, n_f) = count_gendered_pronouns(&record.content).unwrap(); + let record_processed = RecordProcessed { + id: record.id, + n_m: n_m as u64, + n_f: n_f as u64, + }; + result.push(record_processed); + } + Ok(result) +} + +#[pymodule] +fn pyo3_rayon(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(get_pronoun_counts, m)?)?; + Ok(()) +}