Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add serialization support to FstAddOn #162

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
125 changes: 109 additions & 16 deletions rustfst/src/algorithms/compose/add_on.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,38 @@
use std::fmt::Debug;
use std::io::Write;
use std::marker::PhantomData;
use std::sync::Arc;

use anyhow::Result;
use nom::combinator::verify;
use nom::IResult;

use crate::{NomCustomError, StateId, SymbolTable, Tr};
use crate::fst_properties::FstProperties;
use crate::fst_traits::{CoreFst, ExpandedFst, Fst, FstIntoIterator, FstIterator, StateIterator};
use crate::fst_properties::properties::EXPANDED;
use crate::fst_traits::{CoreFst, ExpandedFst, Fst, FstIntoIterator, FstIterator, SerializableFst, StateIterator};
use crate::parsers::{parse_bin_bool, parse_bin_i32, write_bin_bool, write_bin_i32};
use crate::parsers::bin_fst::fst_header::{FST_MAGIC_NUMBER, FstFlags, FstHeader, OpenFstString};
use crate::prelude::{SerializableSemiring, SerializeBinary};
use crate::semirings::Semiring;
use crate::{StateId, SymbolTable};

/// Adds an object of type T to an FST.
/// The resulting type is a new FST implementation.
///
/// The trait implementations in this file forward to the wrapped `fst`; the
/// add-on is carried alongside it and is reachable through the struct itself.
#[derive(Debug, PartialEq, Clone)]
pub struct FstAddOn<F, T> {
pub struct FstAddOn<W, F, T>
where
W: Semiring,
F: Fst<W>
{
/// The wrapped FST that the trait implementations delegate to.
pub(crate) fst: F,
/// The extra object attached to the FST.
pub(crate) add_on: T,
/// Zero-sized marker that lets the struct name the weight type `W`
/// without storing a value of it.
w: PhantomData<W>,
/// FST type name supplied to `new`; the binary serializer writes it into
/// the header's `fst_type` field and restores it from there when parsing.
fst_type: String
}

impl<F, T> FstAddOn<F, T> {
pub fn new(fst: F, add_on: T) -> Self {
Self { fst, add_on }
impl<W: Semiring, F: Fst<W>, T> FstAddOn<W, F, T> {
pub fn new(fst: F, add_on: T, fst_type: String) -> Self {
Self { fst, add_on, w: PhantomData, fst_type }
}

pub fn fst(&self) -> &F {
Expand All @@ -34,7 +48,7 @@ impl<F, T> FstAddOn<F, T> {
}
}

impl<W: Semiring, F: CoreFst<W>, T> CoreFst<W> for FstAddOn<F, T> {
impl<W: Semiring, F: Fst<W>, T> CoreFst<W> for FstAddOn<W, F, T> {
type TRS = F::TRS;

fn start(&self) -> Option<StateId> {
Expand Down Expand Up @@ -78,27 +92,27 @@ impl<W: Semiring, F: CoreFst<W>, T> CoreFst<W> for FstAddOn<F, T> {
}
}

// State iteration for `FstAddOn`: forwarded unchanged to the wrapped FST.
impl<'a, F: StateIterator<'a>, T> StateIterator<'a> for FstAddOn<F, T> {
impl<'a, W: Semiring, F: Fst<W>, T> StateIterator<'a> for FstAddOn<W, F, T> {
// Same iterator type as the wrapped FST's state iterator.
type Iter = <F as StateIterator<'a>>::Iter;

// Returns the wrapped FST's state iterator; the add-on is not consulted.
fn states_iter(&'a self) -> Self::Iter {
self.fst.states_iter()
}
}

// Borrowing FST iteration for `FstAddOn`: forwarded unchanged to the wrapped FST.
impl<'a, W, F, T> FstIterator<'a, W> for FstAddOn<F, T>
impl<'a, W, F, T> FstIterator<'a, W> for FstAddOn<W, F, T>
where
W: Semiring + 'a,
F: FstIterator<'a, W>,
F: Fst<W>,
{
// Same iterator type as the wrapped FST's `FstIterator` implementation.
type FstIter = F::FstIter;
type FstIter = <F as FstIterator<'a, W>>::FstIter;

// Returns the wrapped FST's iterator; the add-on is not consulted.
fn fst_iter(&'a self) -> Self::FstIter {
self.fst.fst_iter()
}
}

impl<W, F, T: Debug> Fst<W> for FstAddOn<F, T>
impl<W, F, T: Debug> Fst<W> for FstAddOn<W, F, T>
where
W: Semiring,
F: Fst<W>,
Expand Down Expand Up @@ -128,7 +142,7 @@ where
}
}

impl<W, F, T> ExpandedFst<W> for FstAddOn<F, T>
impl<W, F, T> ExpandedFst<W> for FstAddOn<W, F, T>
where
W: Semiring,
F: ExpandedFst<W>,
Expand All @@ -139,16 +153,95 @@ where
}
}

// Consuming FST iteration for `FstAddOn`: forwarded to the wrapped FST.
impl<W, F, T> FstIntoIterator<W> for FstAddOn<F, T>
impl<W, F, T> FstIntoIterator<W> for FstAddOn<W, F, T>
where
W: Semiring,
F: FstIntoIterator<W>,
F: FstIntoIterator<W> + Fst<W> ,
T: Debug,
{
// Both associated iterator types are taken from the wrapped FST.
type TrsIter = F::TrsIter;
type FstIter = F::FstIter;
type FstIter = <F as FstIntoIterator<W>>::FstIter;

// Consumes `self` and iterates over the wrapped FST only; the add-on is
// dropped together with the rest of `self`.
fn fst_into_iter(self) -> Self::FstIter {
self.fst.fst_into_iter()
}
}

/// Magic number stored right after the FST header of a serialized `FstAddOn`;
/// parsing fails when the value read from the stream differs.
const ADD_ON_MAGIC_NUMBER: i32 = 446681434;
/// Minimum file version handed to `FstHeader::parse` when reading an `FstAddOn`.
const ADD_ON_MIN_FILE_VERSION: i32 = 1;
/// File version written into the header when serializing an `FstAddOn`.
const ADD_ON_FILE_VERSION: i32 = 1;

/// Binary (de)serialization of an `FstAddOn` whose add-on is a pair of
/// optional, reference-counted values.
///
/// Stream layout, in order: FST header, add-on magic number, the wrapped FST,
/// a `true` flag announcing the add-on pair, then for each pair member a
/// presence flag followed by that member's own encoding when the flag is set.
impl<W, F, AO1, AO2> SerializeBinary for FstAddOn<W, F, (Option<Arc<AO1>>, Option<Arc<AO2>>)>
where
    W: SerializableSemiring,
    F: SerializableFst<W>,
    AO1: SerializeBinary + Debug + Clone + PartialEq,
    AO2: SerializeBinary + Debug + Clone + PartialEq,
{
    /// Reads an `FstAddOn` from `i`, returning the unconsumed input with it.
    ///
    /// Fails when the header is rejected by `FstHeader::parse`, when the magic
    /// number does not match, or when the add-on pair flag is `false`.
    fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
        let (rest, header) = FstHeader::parse(
            i,
            ADD_ON_MIN_FILE_VERSION,
            Option::<&str>::None,
            Tr::<W>::tr_type(),
        )?;

        let (rest, _magic) = verify(parse_bin_i32, |v: &i32| *v == ADD_ON_MAGIC_NUMBER)(rest)?;
        let (rest, inner_fst) = F::parse_binary(rest)?;

        // The pair as a whole must be flagged as present.
        let (rest, _pair_present) = verify(parse_bin_bool, |v| *v)(rest)?;

        let (rest, first) = parse_optional_add_on::<AO1>(rest)?;
        let (rest, second) = parse_optional_add_on::<AO2>(rest)?;

        let parsed = FstAddOn::new(inner_fst, (first, second), header.fst_type.s().clone());
        Ok((rest, parsed))
    }

    /// Writes the header, magic number, wrapped FST and both optional add-ons
    /// to `writer`, in the layout described on the impl.
    fn write_binary<WB: Write>(&self, writer: &mut WB) -> Result<()> {
        // The header only identifies the container: start state and the
        // state/transition counts are written as placeholders.
        let header = FstHeader {
            magic_number: FST_MAGIC_NUMBER,
            fst_type: OpenFstString::new(&self.fst_type),
            tr_type: OpenFstString::new(Tr::<W>::tr_type()),
            version: ADD_ON_FILE_VERSION,
            flags: FstFlags::empty(),
            properties: self.properties().bits() | EXPANDED,
            start: -1,
            num_states: 0,
            num_trs: 0,
            isymt: None,
            osymt: None,
        };
        header.write(writer)?;
        write_bin_i32(writer, ADD_ON_MAGIC_NUMBER)?;
        self.fst.write_binary(writer)?;

        // The add-on pair itself is always reported as present.
        write_bin_bool(writer, true)?;
        write_optional_add_on(writer, self.add_on.0.as_ref())?;
        write_optional_add_on(writer, self.add_on.1.as_ref())?;
        Ok(())
    }
}

/// Reads a presence flag and, when it is set, the add-on value that follows it.
fn parse_optional_add_on<A: SerializeBinary>(
    i: &[u8],
) -> IResult<&[u8], Option<Arc<A>>, NomCustomError<&[u8]>> {
    let (rest, present) = parse_bin_bool(i)?;
    if !present {
        return Ok((rest, None));
    }
    let (rest, value) = A::parse_binary(rest)?;
    Ok((rest, Some(Arc::new(value))))
}

/// Writes a presence flag and, when the add-on exists, its binary encoding.
fn write_optional_add_on<A: SerializeBinary, WB: Write>(
    writer: &mut WB,
    add_on: Option<&Arc<A>>,
) -> Result<()> {
    match add_on {
        Some(value) => {
            write_bin_bool(writer, true)?;
            value.write_binary(writer)?;
        }
        None => {
            write_bin_bool(writer, false)?;
        }
    }
    Ok(())
}
69 changes: 68 additions & 1 deletion rustfst/src/algorithms/compose/interval_set.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::HashSet;
use std::io::Write;
use std::slice::Iter as IterSlice;
use std::vec::IntoIter as IntoIterVec;
use nom::IResult;
use nom::multi::count;
use superslice::Ext;
use unsafe_unwrap::UnsafeUnwrap;
use crate::NomCustomError;
use crate::parsers::{parse_bin_i32, parse_bin_i64, write_bin_i32, write_bin_i64};
use crate::prelude::SerializeBinary;

/// Half-open integral interval [a, b) of signed integers of type T.
#[derive(PartialEq, Clone, Eq, Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -47,6 +53,26 @@ impl Ord for IntInterval {
}
}

impl SerializeBinary for IntInterval {
fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
let (i, begin) = parse_bin_i32(i).map(|(s, v)| (s, v as usize))?;
let (i, end) = parse_bin_i32(i).map(|(s, v)| (s, v as usize))?;
Ok((
i,
IntInterval {
begin,
end
},
))
}

fn write_binary<WB: Write>(&self, writer: &mut WB) -> anyhow::Result<()> {
write_bin_i32(writer, self.begin as i32)?;
write_bin_i32(writer, self.end as i32)?;
Ok(())
}
}

/// Stores IntIntervals in a vector. In addition, keeps the count of points in
/// all intervals.
#[derive(Clone, PartialOrd, PartialEq, Debug)]
Expand Down Expand Up @@ -95,11 +121,52 @@ impl VectorIntervalStore {
}
}

impl SerializeBinary for VectorIntervalStore {
fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
let (i, interval_count) = parse_bin_i64(i).map(|(s, v)| (s, v as usize))?;
let (i, intervals) = count(IntInterval::parse_binary, interval_count)(i)?;
let (i, store_count) = parse_bin_i32(i)?;
Ok((
i,
VectorIntervalStore {
intervals,
count: Some(store_count as usize)
},
))
}

fn write_binary<WB: Write>(&self, writer: &mut WB) -> anyhow::Result<()> {
write_bin_i64(writer, self.intervals.len() as i64)?;
for interval in self.intervals.iter() {
interval.write_binary(writer)?;
}
write_bin_i32(writer, self.count.unwrap_or_default() as i32)?;
Ok(())
}
}

#[derive(PartialOrd, PartialEq, Default, Clone, Debug)]
pub struct IntervalSet {
pub(crate) intervals: VectorIntervalStore,
}

/// Binary (de)serialization of an `IntervalSet`, which is encoded exactly as
/// its underlying `VectorIntervalStore`.
impl SerializeBinary for IntervalSet {
    /// Parses the inner store and wraps it in an `IntervalSet`.
    fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
        VectorIntervalStore::parse_binary(i)
            .map(|(remaining, intervals)| (remaining, IntervalSet { intervals }))
    }

    /// Writes the inner store; nothing else is emitted for the set itself.
    fn write_binary<WB: Write>(&self, writer: &mut WB) -> anyhow::Result<()> {
        self.intervals.write_binary(writer)
    }
}

impl IntervalSet {
pub fn len(&self) -> usize {
self.intervals.len()
Expand Down Expand Up @@ -149,7 +216,7 @@ impl IntervalSet {
elt.begin + 1 == elt.end
}

// Sorts, collapses overlapping and adjacent interals, and sets count.
// Sorts, collapses overlapping and adjacent intervals, and sets count.
pub fn normalize(&mut self) {
let intervals = &mut self.intervals.intervals;
intervals.sort();
Expand Down
73 changes: 72 additions & 1 deletion rustfst/src/algorithms/compose/label_reachable.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::io::Write;
use std::sync::Arc;

use anyhow::Result;
use nom::IResult;
use nom::multi::count;

use crate::algorithms::compose::{IntervalSet, StateReachable};
use crate::algorithms::tr_compares::{ILabelCompare, OLabelCompare};
Expand All @@ -11,7 +14,8 @@ use crate::fst_impls::VectorFst;
use crate::fst_properties::FstProperties;
use crate::fst_traits::{CoreFst, ExpandedFst, Fst, MutableFst};
use crate::semirings::Semiring;
use crate::{Label, StateId, Tr, Trs, EPS_LABEL, NO_LABEL, UNASSIGNED};
use crate::{Label, StateId, Tr, Trs, EPS_LABEL, NO_LABEL, UNASSIGNED, NomCustomError};
use crate::parsers::{parse_bin_bool, parse_bin_i64, parse_bin_i32, parse_bin_u32, SerializeBinary, write_bin_bool, write_bin_i64, write_bin_u32, write_bin_i32};

#[derive(Debug, Clone, PartialEq)]
pub struct LabelReachableData {
Expand Down Expand Up @@ -116,6 +120,73 @@ impl LabelReachableData {
}
}

fn parse_label_map(i: &[u8]) -> IResult<&[u8], HashMap<Label, Label>, NomCustomError<&[u8]>> {
let mut stream = i;
let r = parse_bin_i64(stream).map(|(s, v)| (s, v as usize))?;
stream = r.0;
let map_size = r.1;
let mut map = HashMap::with_capacity(map_size);
for _ in 0..map_size {
let r = parse_bin_i32(stream).map(|(s, v)| (s, v as Label))?;
let key = r.1;
let r = parse_bin_i32(r.0).map(|(s, v)| (s, v as Label))?;
stream = r.0;
let val = r.1;
map.insert(key, val);
}
Ok((stream, map))
}

/// Writes a label-to-label map as a 64-bit entry count followed by the
/// `(key, value)` pairs, each stored as two 32-bit integers.
///
/// Entries are emitted in ascending key order so that the same map always
/// serializes to the same bytes, independent of `HashMap` iteration order.
///
/// # Errors
/// Returns an error when the entry count does not fit in an `i64` or when the
/// underlying write fails.
fn write_label_map<WB: Write>(writer: &mut WB, map: &HashMap<Label, Label>) -> Result<()> {
    let entry_count: i64 = std::convert::TryFrom::try_from(map.len())?;
    write_bin_i64(writer, entry_count)?;

    let mut entries: Vec<(Label, Label)> = map.iter().map(|(k, v)| (*k, *v)).collect();
    entries.sort_unstable();
    for (key, val) in entries {
        // Labels are stored through the same 32-bit cast that `parse_label_map` reverses.
        write_bin_i32(writer, key as i32)?;
        write_bin_i32(writer, val as i32)?;
    }
    Ok(())
}

impl SerializeBinary for LabelReachableData {
fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
let (i, reach_input) = parse_bin_bool(i)?;
let (i, have_relabel_data) = parse_bin_bool(i)?;
let (i, label2index) = if have_relabel_data {
parse_label_map(i)?
} else {
(i, Default::default())
};
let (i, final_label) = parse_bin_u32(i).map(|(s, v)| (s, v as Label))?;
let (i, set_count) = parse_bin_i64(i).map(|(s, v)| (s, v as usize))?;
let (i, interval_sets) = count(IntervalSet::parse_binary, set_count)(i)?;
Ok((
i,
LabelReachableData {
reach_input,
final_label,
label2index,
interval_sets
}
))
}

fn write_binary<WB: Write>(&self, writer: &mut WB) -> Result<()> {
write_bin_bool(writer, self.reach_input)?;
// OpenFst checks keep_relabel_data here which is missing in this struct.
// Instead we check if we have any data in label2index;
let have_relabel_data = !self.label2index.is_empty();
write_bin_bool(writer, have_relabel_data)?;
if have_relabel_data {
write_label_map(writer, &self.label2index)?;
}
write_bin_u32(writer, self.final_label as u32)?;
write_bin_i64(writer, self.interval_sets.len() as i64)?;
for interval_set in self.interval_sets.iter() {
interval_set.write_binary(writer)?;
}
Ok(())
}
}

#[derive(Debug, Clone, PartialEq)]
pub struct LabelReachable {
data: Arc<LabelReachableData>,
Expand Down
Loading