From 68d04a50ceaf2eb2f20911282c1ec3ff219cb543 Mon Sep 17 00:00:00 2001 From: Albert Garde Date: Sun, 28 Jan 2024 15:26:03 +0100 Subject: [PATCH] :art: Change `cargo fmt` configuration --- rustfmt.toml | 4 + src/data/data_objects/metadata_object.rs | 3 +- src/data/data_objects/neuron2graph.rs | 79 ++++++++--- .../neuroscope/neuroscope_page.rs | 1 - src/data/database/data_types/json.rs | 42 ++++-- src/data/database/data_types/neuron2graph.rs | 18 ++- .../database/data_types/neuron_explainer.rs | 36 +++-- src/data/database/data_types/neuron_store.rs | 22 ++- src/data/database/data_types/neuroscope.rs | 58 +++++--- src/data/database/mod.rs | 5 +- src/data/database/model_handle.rs | 32 +++-- src/data/database/service_handle.rs | 7 +- src/data/database/validation.rs | 3 +- src/data/neuron_store.rs | 134 +++++++++++++----- src/data/retrieve/json.rs | 14 +- src/data/retrieve/neuron2graph.rs | 42 ++++-- src/data/retrieve/neuron_explainer.rs | 37 ++++- src/data/retrieve/neuron_store.rs | 54 ++++--- src/data/retrieve/neuroscope.rs | 67 ++++++--- src/index.rs | 19 ++- src/logging.rs | 7 +- src/pyo3/database.rs | 14 +- src/pyo3/model_handle.rs | 48 ++++--- src/pyo3/service_provider.rs | 3 +- src/server/response.rs | 23 ++- src/server/service.rs | 3 +- src/server/service_providers/json.rs | 58 +++++--- src/server/service_providers/metadata.rs | 3 +- src/server/service_providers/neuron2graph.rs | 18 ++- .../service_providers/neuron2graph_search.rs | 17 ++- .../service_providers/neuron_explainer.rs | 30 ++-- src/server/service_providers/neuroscope.rs | 77 ++++++---- src/server/start.rs | 1 - 33 files changed, 657 insertions(+), 322 deletions(-) create mode 100644 rustfmt.toml diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..ad56976 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,4 @@ +unstable_features = true +imports_granularity = "Crate" +group_imports = "StdExternalCrate" +format_strings = true diff --git a/src/data/data_objects/metadata_object.rs b/src/data/data_objects/metadata_object.rs index 2c15815..050ecda 100644 --- a/src/data/data_objects/metadata_object.rs +++ b/src/data/data_objects/metadata_object.rs @@ -1,9 +1,8 @@ use anyhow::Result; use serde::{Deserialize, Serialize}; -use crate::data::{Metadata, ModelHandle}; - use super::{data_object, DataObject}; +use crate::data::{Metadata, ModelHandle}; #[derive(Clone, Serialize, Deserialize)] pub struct MetadataObject { diff --git a/src/data/data_objects/neuron2graph.rs b/src/data/data_objects/neuron2graph.rs index d5bec9d..141db6b 100644 --- a/src/data/data_objects/neuron2graph.rs +++ b/src/data/data_objects/neuron2graph.rs @@ -12,9 +12,8 @@ use graphviz_rust::{ use itertools::Itertools; use serde::{Deserialize, Serialize}; -use crate::data::SimilarNeurons; - use super::{data_object, DataObject}; +use crate::data::SimilarNeurons; fn id_to_str(id: &Id) -> &str { match id { @@ -24,9 +23,13 @@ fn id_to_str(id: &Id) -> &str { fn id_to_usize(id: &Id) -> Result { let id_string = id_to_str(id); - id_string.parse::().with_context(|| format!( - "Could not parse node id {} as usize. It is assumed that all N2G graphs only use positive integer node ids.", id_string - )) + id_string.parse::().with_context(|| { + format!( + "Could not parse node id {} as usize. It is assumed that all N2G graphs only use \ + positive integer node ids.", + id_string + ) + }) } fn dot_node_to_id_label_importance(node: &DotNode) -> Result<(usize, String, f32)> { @@ -40,15 +43,26 @@ fn dot_node_to_id_label_importance(node: &DotNode) -> Result<(usize, String, f32 .find(|Attribute(key, _)| id_to_str(key) == "label") .with_context(|| format!("Node with id {id} has no attribute 'label'."))?; // Assume that the `fillcolor` attribute is a 9 character string with '"' enclosing a hexadecimal color code. - let color_str = get_attribute(attributes.as_slice(), "fillcolor").with_context(|| format!( - "Node {id} has no attribute 'fillcolor'. It is assumed that all N2G nodes have a 'fillcolor' attribute that signifies their importance." - ))?; - let importance_hex = color_str.get(4..6).with_context(|| format!( - "The 'fillcolor' attribute of node {id} is insufficiently long. It is expected to be 9 characters long." - ))?; - let importance = 1.-u8::from_str_radix(importance_hex, 16).with_context(|| format!( - "The green part of the 'fillcolor' attribute of node {id} is not a valid hexadecimal number." - ))? as f32 / 255.0; + let color_str = get_attribute(attributes.as_slice(), "fillcolor").with_context(|| { + format!( + "Node {id} has no attribute 'fillcolor'. It is assumed that all N2G nodes have a \ + 'fillcolor' attribute that signifies their importance." + ) + })?; + let importance_hex = color_str.get(4..6).with_context(|| { + format!( + "The 'fillcolor' attribute of node {id} is insufficiently long. It is expected to be \ + 9 characters long." + ) + })?; + let importance = 1. + - u8::from_str_radix(importance_hex, 16).with_context(|| { + format!( + "The green part of the 'fillcolor' attribute of node {id} is not a valid \ + hexadecimal number." + ) + })? as f32 + / 255.0; let label = id_to_str(label_id).to_string(); Ok((id, label, importance)) @@ -70,7 +84,12 @@ fn subgraph_to_nodes(subgraph: &Subgraph) -> Result> { let id_str = id_to_str(id); let id: usize = id_str .strip_prefix("cluster_") - .with_context(|| format!("It is assumed that all N2G subgraphs have ids starting with 'cluster_'. Subgraph id: {id_str}"))? + .with_context(|| { + format!( + "It is assumed that all N2G subgraphs have ids starting with 'cluster_'. Subgraph \ + id: {id_str}" + ) + })? .parse::() .with_context(|| format!("Failed to parse subgraph id '{id_str}' as usize."))?; let nodes = statements @@ -92,16 +111,34 @@ fn dot_edge_to_ids( ) -> Result<(usize, usize)> { match edge_ty { EdgeTy::Pair(Vertex::N(NodeId(node_id1, _)), Vertex::N(NodeId(node_id2, _))) => { - let id1 = id_to_usize(node_id1).with_context(|| format!("Failed to parse first id for edge {edge_ty:?}."))?; - let id2 = id_to_usize(node_id2).with_context(|| format!("Failed to parse second id for edge {edge_ty:?}."))?; + let id1 = id_to_usize(node_id1) + .with_context(|| format!("Failed to parse first id for edge {edge_ty:?}."))?; + let id2 = id_to_usize(node_id2) + .with_context(|| format!("Failed to parse second id for edge {edge_ty:?}."))?; match get_attribute(attributes, "dir") { Some("back") => Ok((id2, id1)), - None => bail!("No direction attribute found for edge {id1}->{id2}. It is assumed that all N2G graphs only use edges with direction 'back'."), - _ => bail!("Only edges with direction 'back' or 'forward' are supported. It is assumed that all N2G graphs only use edges with direction 'back' or 'forward'. Edge: {:?}", edge_ty) + None => bail!( + "No direction attribute found for edge {id1}->{id2}. It is assumed that all \ + N2G graphs only use edges with direction 'back'." + ), + _ => bail!( + "Only edges with direction 'back' or 'forward' are supported. It is assumed \ + that all N2G graphs only use edges with direction 'back' or 'forward'. Edge: \ + {:?}", + edge_ty + ), } } - EdgeTy::Pair(_, _) => bail!("Only edges between individual nodes are supported. It is assumed that N2G does not use edges between subgraphs. Edge: {:?}", edge_ty), - EdgeTy::Chain(_) => bail!("Only pair edges are supported. It is assumed that all N2G graphs only use pair edges. Edge: {:?}", edge_ty) + EdgeTy::Pair(_, _) => bail!( + "Only edges between individual nodes are supported. It is assumed that N2G does not \ + use edges between subgraphs. Edge: {:?}", + edge_ty + ), + EdgeTy::Chain(_) => bail!( + "Only pair edges are supported. It is assumed that all N2G graphs only use pair \ + edges. Edge: {:?}", + edge_ty + ), } } diff --git a/src/data/data_objects/neuroscope/neuroscope_page.rs b/src/data/data_objects/neuroscope/neuroscope_page.rs index f60115c..a27d449 100644 --- a/src/data/data_objects/neuroscope/neuroscope_page.rs +++ b/src/data/data_objects/neuroscope/neuroscope_page.rs @@ -4,7 +4,6 @@ use anyhow::{bail, Context, Result}; use itertools::Itertools; use regex::Regex; use serde::{Deserialize, Serialize}; - use utoipa::ToSchema; use crate::data::{ diff --git a/src/data/database/data_types/json.rs b/src/data/database/data_types/json.rs index 608c5a9..a4a76bb 100644 --- a/src/data/database/data_types/json.rs +++ b/src/data/database/data_types/json.rs @@ -1,5 +1,7 @@ +use anyhow::{bail, Context, Result}; use async_trait::async_trait; +use super::{data_type::DataValidationError, DataTypeDiscriminants, ModelDataType}; use crate::{ data::{ data_objects::{DataObject, JsonData}, @@ -8,10 +10,6 @@ use crate::{ Index, }; -use super::{data_type::DataValidationError, DataTypeDiscriminants, ModelDataType}; - -use anyhow::{bail, Context, Result}; - pub struct Json { model: ModelHandle, data_type: DataTypeHandle, @@ -81,26 +79,42 @@ impl Json { pub async fn layer_page(&self, layer_index: u32) -> Result { let model_name = self.model.name(); let data_type_name = self.data_type.name(); - let raw_data = self.model - .layer_data( &self.data_type, layer_index) - .await.with_context(|| { - format!("Failed to get '{data_type_name}' layer data for layer {layer_index} in model '{model_name}'.") + let raw_data = self + .model + .layer_data(&self.data_type, layer_index) + .await + .with_context(|| { + format!( + "Failed to get '{data_type_name}' layer data for layer {layer_index} in model \ + '{model_name}'." + ) })? .with_context(|| { - format!("Database has no '{data_type_name}' layer data for layer {layer_index} in model '{model_name}'.") + format!( + "Database has no '{data_type_name}' layer data for layer {layer_index} in \ + model '{model_name}'." + ) })?; JsonData::from_binary(raw_data.as_slice()) } pub async fn neuron_page(&self, layer_index: u32, neuron_index: u32) -> Result { let model_name = self.model.name(); let data_type_name = self.data_type.name(); - let raw_data = self.model - .neuron_data( &self.data_type, layer_index, neuron_index) - .await.with_context(|| { - format!("Failed to get '{data_type_name}' neuron data for neuron l{layer_index}n{neuron_index} in model '{model_name}'.") + let raw_data = self + .model + .neuron_data(&self.data_type, layer_index, neuron_index) + .await + .with_context(|| { + format!( + "Failed to get '{data_type_name}' neuron data for neuron \ + l{layer_index}n{neuron_index} in model '{model_name}'." + ) })? .with_context(|| { - format!("Database has no '{data_type_name}' neuron data for neuron l{layer_index}n{neuron_index} in model '{model_name}'.") + format!( + "Database has no '{data_type_name}' neuron data for neuron \ + l{layer_index}n{neuron_index} in model '{model_name}'." + ) })?; JsonData::from_binary(raw_data.as_slice()) } diff --git a/src/data/database/data_types/neuron2graph.rs b/src/data/database/data_types/neuron2graph.rs index b4a224f..bf0f392 100644 --- a/src/data/database/data_types/neuron2graph.rs +++ b/src/data/database/data_types/neuron2graph.rs @@ -1,13 +1,12 @@ use anyhow::{bail, Context, Result}; use async_trait::async_trait; +use super::{data_type::DataValidationError, DataTypeDiscriminants, ModelDataType}; use crate::data::{ data_objects::{DataObject, Graph}, DataTypeHandle, ModelHandle, }; -use super::{data_type::DataValidationError, DataTypeDiscriminants, ModelDataType}; - pub struct Neuron2Graph { model: ModelHandle, data_type: DataTypeHandle, @@ -47,12 +46,21 @@ impl ModelDataType for Neuron2Graph { impl Neuron2Graph { pub async fn neuron_graph(&self, layer_index: u32, neuron_index: u32) -> Result { let model_name = self.model.name(); - let raw_data = self.model + let raw_data = self + .model .neuron_data(&self.data_type, layer_index, neuron_index) .await? .with_context(|| { - format!("Database has no neuron2graph data for neuron l{layer_index}n{neuron_index} in model '{model_name}'") + format!( + "Database has no neuron2graph data for neuron l{layer_index}n{neuron_index} \ + in model '{model_name}'" + ) })?; - Graph::from_binary(raw_data).with_context(|| format!("Failed to unpack neuron2graph graph for neuron l{layer_index}n{neuron_index} in model '{model_name}'.")) + Graph::from_binary(raw_data).with_context(|| { + format!( + "Failed to unpack neuron2graph graph for neuron l{layer_index}n{neuron_index} in \ + model '{model_name}'." + ) + }) } } diff --git a/src/data/database/data_types/neuron_explainer.rs b/src/data/database/data_types/neuron_explainer.rs index e3a4b46..83c0ef1 100644 --- a/src/data/database/data_types/neuron_explainer.rs +++ b/src/data/database/data_types/neuron_explainer.rs @@ -1,17 +1,16 @@ use anyhow::{bail, Context, Result}; use async_trait::async_trait; +use super::{ + data_type::{DataValidationError, ModelDataType}, + DataTypeDiscriminants, +}; use crate::data::{ data_objects::{DataObject, NeuronExplainerPage}, database::ModelHandle, DataTypeHandle, }; -use super::{ - data_type::{DataValidationError, ModelDataType}, - DataTypeDiscriminants, -}; - pub struct NeuronExplainer { model: ModelHandle, data_type: DataTypeHandle, @@ -55,14 +54,25 @@ impl NeuronExplainer { neuron_index: u32, ) -> Result> { let model_name = self.model.name(); - let raw_data = self.model - .neuron_data( &self.data_type, layer_index, neuron_index) - .await.with_context(|| { - format!("Failed to get neuron explainer neuron data for neuron l{layer_index}n{neuron_index} in model '{model_name}'.") - })?; - raw_data.map(|raw_data| NeuronExplainerPage::from_binary(raw_data.as_slice()) + let raw_data = self + .model + .neuron_data(&self.data_type, layer_index, neuron_index) + .await .with_context(|| { - format!("Failed to deserialize neuron explainer neuron data for neuron l{layer_index}n{neuron_index} in model '{model_name}'.") - })).transpose() + format!( + "Failed to get neuron explainer neuron data for neuron \ + l{layer_index}n{neuron_index} in model '{model_name}'." + ) + })?; + raw_data + .map(|raw_data| { + NeuronExplainerPage::from_binary(raw_data.as_slice()).with_context(|| { + format!( + "Failed to deserialize neuron explainer neuron data for neuron \ + l{layer_index}n{neuron_index} in model '{model_name}'." + ) + }) + }) + .transpose() } } diff --git a/src/data/database/data_types/neuron_store.rs b/src/data/database/data_types/neuron_store.rs index eb85635..483550b 100644 --- a/src/data/database/data_types/neuron_store.rs +++ b/src/data/database/data_types/neuron_store.rs @@ -1,12 +1,11 @@ use anyhow::{bail, Context, Result}; use async_trait::async_trait; +use super::{data_type::DataValidationError, DataTypeDiscriminants, ModelDataType}; use crate::data::{ neuron_store::SimilarNeurons, DataTypeHandle, ModelHandle, NeuronStore as NeuronStoreData, }; -use super::{data_type::DataValidationError, DataTypeDiscriminants, ModelDataType}; - pub struct NeuronStore { model: ModelHandle, data_type: DataTypeHandle, @@ -68,10 +67,23 @@ impl NeuronStore { .model .neuron_data(&self.data_type, layer_index, neuron_index) .await - .with_context(|| format!("Failed to get neuron store data for neuron l{layer_index}n{neuron_index} in model '{model_name}'.",))? .with_context(|| { - format!("Database has no neuron store data for neuron l{layer_index}n{neuron_index} in model '{model_name}'.") + format!( + "Failed to get neuron store data for neuron l{layer_index}n{neuron_index} in \ + model '{model_name}'.", + ) + })? + .with_context(|| { + format!( + "Database has no neuron store data for neuron l{layer_index}n{neuron_index} \ + in model '{model_name}'." + ) })?; - SimilarNeurons::from_binary(raw_data.as_slice()).with_context(|| format!("Failed to deserialize neuron similarities for neuron l{layer_index}n{neuron_index} in model '{model_name}'.")) + SimilarNeurons::from_binary(raw_data.as_slice()).with_context(|| { + format!( + "Failed to deserialize neuron similarities for neuron \ + l{layer_index}n{neuron_index} in model '{model_name}'." + ) + }) } } diff --git a/src/data/database/data_types/neuroscope.rs b/src/data/database/data_types/neuroscope.rs index 27b7a3f..1079de4 100644 --- a/src/data/database/data_types/neuroscope.rs +++ b/src/data/database/data_types/neuroscope.rs @@ -1,18 +1,15 @@ use anyhow::{bail, Context, Result}; use async_trait::async_trait; -use crate::data::{ - data_objects::NeuroscopeLayerPage, - data_objects::NeuroscopeNeuronPage, - data_objects::{DataObject, NeuroscopeModelPage}, - database::ModelHandle, - DataTypeHandle, -}; - use super::{ data_type::{DataValidationError, ModelDataType}, DataTypeDiscriminants, }; +use crate::data::{ + data_objects::{DataObject, NeuroscopeLayerPage, NeuroscopeModelPage, NeuroscopeNeuronPage}, + database::ModelHandle, + DataTypeHandle, +}; pub struct Neuroscope { model: ModelHandle, @@ -63,13 +60,21 @@ impl Neuroscope { } pub async fn layer_page(&self, layer_index: u32) -> Result { let model_name = self.model.name(); - let raw_data = self.model - .layer_data( &self.data_type, layer_index) - .await.with_context(|| { - format!("Failed to get neuroscope layer data for layer {layer_index} in model '{model_name}'.") + let raw_data = self + .model + .layer_data(&self.data_type, layer_index) + .await + .with_context(|| { + format!( + "Failed to get neuroscope layer data for layer {layer_index} in model \ + '{model_name}'." + ) })? .with_context(|| { - format!("Database has no neuroscope layer data for layer {layer_index} in model '{model_name}'.") + format!( + "Database has no neuroscope layer data for layer {layer_index} in model \ + '{model_name}'." + ) })?; NeuroscopeLayerPage::from_binary(raw_data.as_slice()) } @@ -79,14 +84,25 @@ impl Neuroscope { neuron_index: u32, ) -> Result> { let model_name = self.model.name(); - let raw_data = self.model - .neuron_data( &self.data_type, layer_index, neuron_index) - .await.with_context(|| { - format!("Failed to get neuroscope neuron data for neuron l{layer_index}n{neuron_index} in model '{model_name}'.") - })?; - raw_data.map(|raw_data| NeuroscopeNeuronPage::from_binary(raw_data.as_slice()) + let raw_data = self + .model + .neuron_data(&self.data_type, layer_index, neuron_index) + .await .with_context(|| { - format!("Failed to deserialize neuroscope neuron data for neuron l{layer_index}n{neuron_index} in model '{model_name}'.") - })).transpose() + format!( + "Failed to get neuroscope neuron data for neuron \ + l{layer_index}n{neuron_index} in model '{model_name}'." + ) + })?; + raw_data + .map(|raw_data| { + NeuroscopeNeuronPage::from_binary(raw_data.as_slice()).with_context(|| { + format!( + "Failed to deserialize neuroscope neuron data for neuron \ + l{layer_index}n{neuron_index} in model '{model_name}'." + ) + }) + }) + .transpose() } } diff --git a/src/data/database/mod.rs b/src/data/database/mod.rs index 1a348bd..f4bafbd 100644 --- a/src/data/database/mod.rs +++ b/src/data/database/mod.rs @@ -1,15 +1,12 @@ use std::path::Path; use anyhow::{bail, Context, Result}; - use rusqlite::Transaction; use tokio_rusqlite::Connection; -use crate::server::{Service, ServiceProvider}; - use self::data_types::ModelDataType; - use super::{data_types::DataType, Metadata}; +use crate::server::{Service, ServiceProvider}; mod model_handle; pub use model_handle::ModelHandle; diff --git a/src/data/database/model_handle.rs b/src/data/database/model_handle.rs index baa6c40..b19b8f5 100644 --- a/src/data/database/model_handle.rs +++ b/src/data/database/model_handle.rs @@ -1,11 +1,10 @@ -use crate::{data::Metadata, Index}; +use anyhow::{Context, Result}; +use rusqlite::OptionalExtension; use super::{ data_types::ModelDataType, service_handle::ServiceHandle, DataTypeHandle, Database, Operation, }; - -use anyhow::{Context, Result}; -use rusqlite::OptionalExtension; +use crate::{data::Metadata, Index}; #[derive(Clone)] pub struct ModelHandle { @@ -269,14 +268,21 @@ impl ModelHandle { .await .context("Failed to get list of services.")? { - if self.missing_data_types(&service).await.with_context(|| - format!("Failed to get list of missing data objects for model '{model_name}' and service '{service_name}'.", - model_name = self.name(), - service_name = service.name() + if self + .missing_data_types(&service) + .await + .with_context(|| { + format!( + "Failed to get list of missing data objects for model '{model_name}' and \ + service '{service_name}'.", + model_name = self.name(), + service_name = service.name() ) - )?.is_empty() { + })? + .is_empty() + { services.push(service); - } + } } Ok(services) } @@ -466,7 +472,8 @@ impl ModelHandle { .await .with_context(|| { format!( - "Failed to get layer data for layer {layer_index} data object '{}' for model '{}'.", + "Failed to get layer data for layer {layer_index} data object '{}' for model \ + '{}'.", self.name(), data_type.name() ) @@ -497,7 +504,8 @@ impl ModelHandle { .await .with_context(|| { format!( - "Failed to get neuron data for neuron l{layer_index}n{neuron_index} for data object '{}' for model '{}'.", + "Failed to get neuron data for neuron l{layer_index}n{neuron_index} for data \ + object '{}' for model '{}'.", data_type.name(), self.name(), ) diff --git a/src/data/database/service_handle.rs b/src/data/database/service_handle.rs index c6bc6db..41aa1d0 100644 --- a/src/data/database/service_handle.rs +++ b/src/data/database/service_handle.rs @@ -1,10 +1,9 @@ -use crate::server::{Service, ServiceProvider}; - -use super::{DataTypeHandle, Database, Operation}; - use anyhow::{Context, Result}; use rusqlite::OptionalExtension; +use super::{DataTypeHandle, Database, Operation}; +use crate::server::{Service, ServiceProvider}; + #[derive(Clone)] pub struct ServiceHandle { id: i64, diff --git a/src/data/database/validation.rs b/src/data/database/validation.rs index 134243b..e743bd0 100644 --- a/src/data/database/validation.rs +++ b/src/data/database/validation.rs @@ -1,8 +1,7 @@ use anyhow::{bail, Context}; -use crate::{data::NeuronIndex, Index}; - use super::{DataTypeHandle, ModelHandle}; +use crate::{data::NeuronIndex, Index}; impl ModelHandle { pub async fn missing_model_items( diff --git a/src/data/neuron_store.rs b/src/data/neuron_store.rs index 0fa8013..fc0a7d3 100644 --- a/src/data/neuron_store.rs +++ b/src/data/neuron_store.rs @@ -1,14 +1,16 @@ -use std::collections::{HashMap, HashSet}; -use std::fs; -use std::io::Write; -use std::path::Path; -use std::time::Instant; -use std::{fmt::Display, str::FromStr}; +use std::{ + collections::{HashMap, HashSet}, + fmt::Display, + fs, + io::Write, + path::Path, + str::FromStr, + time::Instant, +}; use anyhow::{bail, Context, Result}; use itertools::Itertools; -use serde::Deserialize; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use snap::raw::{Decoder, Encoder}; use super::NeuronIndex; @@ -108,7 +110,9 @@ pub struct NeuronSimilarity { impl NeuronSimilarity { pub fn similar_neurons(&self, neuron_index: NeuronIndex) -> Result<&SimilarNeurons> { let index = neuron_index.flat_index(self.layer_size); - self.similar_neurons.get(index).with_context(|| format!("No similar neuron array for neuron index {neuron_index}.")) + self.similar_neurons + .get(index) + .with_context(|| format!("No similar neuron array for neuron index {neuron_index}.")) } } @@ -148,9 +152,15 @@ impl NeuronStore { for (token_id, (_token, neuron_indices)) in self.activating.iter().enumerate() { for &neuron_index in neuron_indices { let index = neuron_index.flat_index(self.layer_size); - activating_tokens.get_mut(index).with_context(|| - format!("Index {index} somehow greater than the number of neurons. This should not be possible") - )?.push(token_id as u32); + activating_tokens + .get_mut(index) + .with_context(|| { + format!( + "Index {index} somehow greater than the number of neurons. This \ + should not be possible" + ) + })? + .push(token_id as u32); } } @@ -160,62 +170,110 @@ impl NeuronStore { for (token_id, (_token, neuron_indices)) in self.important.iter().enumerate() { for &neuron_index in neuron_indices { let index = neuron_index.flat_index(self.layer_size); - important_tokens.get_mut(index).with_context(|| - format!("Index {index} somehow greater than the number of neurons. This should not be possible") - )?.push(token_id as u32); + important_tokens + .get_mut(index) + .with_context(|| { + format!( + "Index {index} somehow greater than the number of neurons. This \ + should not be possible" + ) + })? + .push(token_id as u32); } } println!("Finding similar neurons..."); std::io::stdout().flush().unwrap(); - let mut similar_neurons: Vec = (0..num_neurons).map(|_| SimilarNeurons { similar_neurons: vec![] }).collect(); + let mut similar_neurons: Vec = (0..num_neurons) + .map(|_| SimilarNeurons { + similar_neurons: vec![], + }) + .collect(); let start = Instant::now(); - for (neuron, (this_activating_tokens, this_important_tokens)) in activating_tokens.iter().zip(important_tokens.iter()).enumerate() { + for (neuron, (this_activating_tokens, this_important_tokens)) in activating_tokens + .iter() + .zip(important_tokens.iter()) + .enumerate() + { let this_neuron_index = NeuronIndex::from_flat_index(self.layer_size, neuron); let neurons_per_second = (neuron as f32) / start.elapsed().as_secs_f32(); - print!("Neuron {this_neuron_index}. Neurons per second: {neurons_per_second:.0} \r"); - for (other_neuron, (other_activating_tokens, other_important_tokens)) in activating_tokens.iter().zip(important_tokens.iter()).enumerate().skip(neuron + 1) { + print!( + "Neuron {this_neuron_index}. Neurons per second: {neurons_per_second:.0} \r" + ); + for (other_neuron, (other_activating_tokens, other_important_tokens)) in + activating_tokens + .iter() + .zip(important_tokens.iter()) + .enumerate() + .skip(neuron + 1) + { let mut total_common = 0; - + let mut this_activating_iter = this_activating_tokens.iter().copied().peekable(); let mut other_activating_iter = other_activating_tokens.iter().copied().peekable(); - while let (Some(this_activating), Some(other_activating)) = (this_activating_iter.peek(), other_activating_iter.peek()) { + while let (Some(this_activating), Some(other_activating)) = + (this_activating_iter.peek(), other_activating_iter.peek()) + { match this_activating.cmp(other_activating) { - std::cmp::Ordering::Less => { this_activating_iter.next(); }, + std::cmp::Ordering::Less => { + this_activating_iter.next(); + } std::cmp::Ordering::Equal => { total_common += 1; this_activating_iter.next(); other_activating_iter.next(); - - }, - std::cmp::Ordering::Greater => { other_activating_iter.next(); }, + } + std::cmp::Ordering::Greater => { + other_activating_iter.next(); + } } } let mut this_important_iter = this_important_tokens.iter().copied().peekable(); let mut other_important_iter = other_important_tokens.iter().copied().peekable(); - while let (Some(this_important), Some(other_important)) = (this_important_iter.peek(), other_important_iter.peek()) { + while let (Some(this_important), Some(other_important)) = + (this_important_iter.peek(), other_important_iter.peek()) + { match this_important.cmp(other_important) { - std::cmp::Ordering::Less => { this_important_iter.next(); }, + std::cmp::Ordering::Less => { + this_important_iter.next(); + } std::cmp::Ordering::Equal => { total_common += 1; this_important_iter.next(); other_important_iter.next(); - - }, - std::cmp::Ordering::Greater => { other_important_iter.next(); }, + } + std::cmp::Ordering::Greater => { + other_important_iter.next(); + } } } - let possible_common = activating_tokens.len().min(other_activating_tokens.len()) + important_tokens.len().min(other_important_tokens.len()); + let possible_common = activating_tokens.len().min(other_activating_tokens.len()) + + important_tokens.len().min(other_important_tokens.len()); let similarity = (total_common as f32) / (possible_common as f32); if similarity >= threshold { - let other_neuron_index = NeuronIndex::from_flat_index(self.layer_size, other_neuron); - similar_neurons.get_mut(neuron).with_context(|| - format!("Index {neuron} of neuron somehow greater than the number of neurons. This should not be possible") - )?.similar_neurons.push(SimilarNeuron::new(other_neuron_index, similarity)); - similar_neurons.get_mut(other_neuron).with_context(|| - format!("Index {other_neuron} of other neuron somehow greater than the number of neurons. This should not be possible") - )?.similar_neurons.push(SimilarNeuron::new(this_neuron_index, similarity)); + let other_neuron_index = + NeuronIndex::from_flat_index(self.layer_size, other_neuron); + similar_neurons + .get_mut(neuron) + .with_context(|| { + format!( + "Index {neuron} of neuron somehow greater than the number of \ + neurons. This should not be possible" + ) + })? + .similar_neurons + .push(SimilarNeuron::new(other_neuron_index, similarity)); + similar_neurons + .get_mut(other_neuron) + .with_context(|| { + format!( + "Index {other_neuron} of other neuron somehow greater than the \ + number of neurons. This should not be possible" + ) + })? + .similar_neurons + .push(SimilarNeuron::new(this_neuron_index, similarity)); } } } diff --git a/src/data/retrieve/json.rs b/src/data/retrieve/json.rs index f4c131b..bac6635 100644 --- a/src/data/retrieve/json.rs +++ b/src/data/retrieve/json.rs @@ -1,3 +1,5 @@ +use anyhow::{bail, Context, Result}; + use crate::{ data::{ data_objects::{DataObject, JsonData}, @@ -7,8 +9,6 @@ use crate::{ Index, }; -use anyhow::{bail, Context, Result}; - pub async fn store_json_data( model_handle: &mut ModelHandle, data_type_handle: &DataTypeHandle, @@ -26,9 +26,13 @@ pub async fn store_json_data( .add_data( data_type_handle, index, - data.to_binary().with_context(|| - format!("Failed to serialize JSON data of data object '{data_type_name}' for {index} in model '{model_name}'.", index = index.error_string()) - )?, + data.to_binary().with_context(|| { + format!( + "Failed to serialize JSON data of data object '{data_type_name}' for {index} \ + in model '{model_name}'.", + index = index.error_string() + ) + })?, ) .await } diff --git a/src/data/retrieve/neuron2graph.rs b/src/data/retrieve/neuron2graph.rs index 26e01ab..7523aff 100644 --- a/src/data/retrieve/neuron2graph.rs +++ b/src/data/retrieve/neuron2graph.rs @@ -1,15 +1,15 @@ use std::path::{Path, PathBuf}; +use anyhow::{bail, Context, Result}; +use regex::Regex; +use tokio::fs; + use crate::data::{ data_objects::{DataObject, Graph}, data_types::DataType, DataTypeHandle, ModelHandle, NeuronIndex, }; -use anyhow::{bail, Context, Result}; -use regex::Regex; -use tokio::fs; - fn neuron_path(root: impl AsRef, neuron_index: NeuronIndex) -> PathBuf { let NeuronIndex { layer, neuron } = neuron_index; root.as_ref() @@ -32,19 +32,32 @@ async fn retrieve_neuron2graph_neuron( let graph = match graphviz_rust::parse(graph_str.as_ref()) { Ok(graph) => graph, Err(parse_error) => { - bail!("Failed to parse graph for neuron {neuron_index} in model '{}'. Error: '{parse_error}'", model_handle.name()) + bail!( + "Failed to parse graph for neuron {neuron_index} in model '{}'. Error: \ + '{parse_error}'", + model_handle.name() + ) } }; - Graph::from_dot(graph).with_context(|| - format!("Succesfully parsed graph, but graph is not a valid neuron2graph grpah. Neuron {neuron_index} in model '{}'.", model_handle.name()) - )? + Graph::from_dot(graph).with_context(|| { + format!( + "Succesfully parsed graph, but graph is not a valid neuron2graph grpah. \ + Neuron {neuron_index} in model '{}'.", + model_handle.name() + ) + })? } Err(read_err) => { if read_err.kind() == std::io::ErrorKind::NotFound { return Ok(false); } else { - return Err(read_err) - .with_context(|| format!("Failed to read neuron2graph graph file for neuron {neuron_index} in model '{}'.", model_handle.name())); + return Err(read_err).with_context(|| { + format!( + "Failed to read neuron2graph graph file for neuron {neuron_index} in \ + model '{}'.", + model_handle.name() + ) + }); } } }; @@ -63,7 +76,6 @@ pub async fn retrieve_neuron2graph( model_handle: &mut ModelHandle, path: impl AsRef, ) -> Result<()> { - let path = path.as_ref(); if !path.is_dir() { @@ -106,9 +118,7 @@ pub async fn retrieve_neuron2graph( print!("Storing neuron graphs: 0/{num_total_neurons}"); let mut num_missing = 0; for neuron_index in model_handle.metadata().neuron_indices() { - if !retrieve_neuron2graph_neuron(model_handle, &data_type, path, neuron_index) - .await? - { + if !retrieve_neuron2graph_neuron(model_handle, &data_type, path, neuron_index).await? { num_missing += 1 } @@ -117,6 +127,8 @@ pub async fn retrieve_neuron2graph( neuron_index.flat_index(layer_size) ); } - println!("\rStored all {num_total_neurons} neuron graphs. {num_missing} were missing. "); + println!( + "\rStored all {num_total_neurons} neuron graphs. {num_missing} were missing. " + ); Ok(()) } diff --git a/src/data/retrieve/neuron_explainer.rs b/src/data/retrieve/neuron_explainer.rs index 0a649b9..940c80c 100644 --- a/src/data/retrieve/neuron_explainer.rs +++ b/src/data/retrieve/neuron_explainer.rs @@ -38,7 +38,10 @@ pub fn model_url(model_name: &str, index: NeuronIndex) -> Result { match model_name { "gpt2-small" => Ok(small_url(index)), "gpt2-xl" => Ok(xl_url(index)), - _ => bail!("Neuron explainer retrieval only available for models 'gpt2-small' and 'gpt2-xl'. Given model name: {model_name}"), + _ => bail!( + "Neuron explainer retrieval only available for models 'gpt2-small' and 'gpt2-xl'. \ + Given model name: {model_name}" + ), } } @@ -82,11 +85,15 @@ async fn fetch( Ok(result) => break Some(result), Err(err) => { if retries == RETRY_LIMIT { - log::error!("Failed to fetch neuron explainer data for neuron {index} after {retries} retries. Error: {err}"); + log::error!( + "Failed to fetch neuron explainer data for neuron {index} \ + after {retries} retries. Error: {err}" + ); break None; } log::error!( - "Failed to fetch neuron explainer data for neuron {index}. Retrying...", + "Failed to fetch neuron explainer data for neuron {index}. \ + Retrying...", ); log::error!("Error: {err}"); retries += 1; @@ -156,15 +163,31 @@ pub async fn retrieve_neuron_explainer_small(model_handle: &mut ModelHandle) -> println!("Retrieving neuron explainer data for GPT-2 small."); let num_layers = model_handle.metadata().num_layers; let layer_size = model_handle.metadata().layer_size; - ensure!(num_layers == SMALL_NUM_LAYERS, "Model has wrong number of layers. GPT-2 small has {SMALL_NUM_LAYERS} but model has {num_layers}."); - ensure!(layer_size == SMALL_LAYER_SIZE, "Model has wrong layer size. GPT-2 small has {SMALL_LAYER_SIZE} neurons per layer, but model has {layer_size}"); + ensure!( + num_layers == SMALL_NUM_LAYERS, + "Model has wrong number of layers. GPT-2 small has {SMALL_NUM_LAYERS} but model has \ + {num_layers}." + ); + ensure!( + layer_size == SMALL_LAYER_SIZE, + "Model has wrong layer size. GPT-2 small has {SMALL_LAYER_SIZE} neurons per layer, but \ + model has {layer_size}" + ); fetch_to_database(model_handle, small_url).await } pub async fn retrieve_neuron_explainer_xl(model_handle: &mut ModelHandle) -> Result<()> { let num_layers = model_handle.metadata().num_layers; let layer_size = model_handle.metadata().layer_size; - ensure!(num_layers == XL_NUM_LAYERS, "Model has wrong number of layers. GPT-2 XL has {XL_NUM_LAYERS} but model has {num_layers}."); - ensure!(layer_size == XL_LAYER_SIZE, "Model has wrong layer size. GPT-2 XL has {XL_LAYER_SIZE} neurons per layer, but model has {layer_size}"); + ensure!( + num_layers == XL_NUM_LAYERS, + "Model has wrong number of layers. GPT-2 XL has {XL_NUM_LAYERS} but model has \ + {num_layers}." + ); + ensure!( + layer_size == XL_LAYER_SIZE, + "Model has wrong layer size. GPT-2 XL has {XL_LAYER_SIZE} neurons per layer, but model \ + has {layer_size}" + ); fetch_to_database(model_handle, xl_url).await } diff --git a/src/data/retrieve/neuron_store.rs b/src/data/retrieve/neuron_store.rs index 95a66d9..c808d4e 100644 --- a/src/data/retrieve/neuron_store.rs +++ b/src/data/retrieve/neuron_store.rs @@ -1,12 +1,12 @@ use std::path::Path; +use anyhow::{bail, Context, Result}; + use crate::data::{ data_types::DataType, neuron_store::NeuronStoreRaw, DataTypeHandle, Database, ModelHandle, NeuronStore, }; -use anyhow::{bail, Context, Result}; - pub async fn store_similar_neurons( model_handle: &mut ModelHandle, data_type_handle: &DataTypeHandle, @@ -16,29 +16,51 @@ pub async fn store_similar_neurons( let model_name = model_handle.name().to_owned(); let model_name = model_name.as_str(); print!("Calculating neuron similarities..."); - let neuron_relatedness = neuron_store.neuron_similarity(similarity_threshold).with_context(|| - format!("Failed to calculate neuron similarities for model '{model_name}'.",) - )?; + let neuron_relatedness = neuron_store + .neuron_similarity(similarity_threshold) + .with_context(|| { + format!("Failed to calculate neuron similarities for model '{model_name}'.",) + })?; let num_neurons = model_handle.metadata().num_total_neurons; let mut num_completed = 0; print!("Adding neuron similarities to database: {num_completed}/{num_neurons}",); for neuron_index in model_handle.metadata().neuron_indices() { - let similar_neurons = - neuron_relatedness.similar_neurons(neuron_index).with_context(|| - format!("Failed to get similar neurons for neuron {neuron_index} in model '{model_name}'.") - )?; - let data = similar_neurons.to_binary().with_context(|| - format!("Failed to serialize similar neuron vector for neuron {neuron_index} in model {model_name}.") - )?; - model_handle.add_neuron_data(data_type_handle, neuron_index.layer, neuron_index.neuron, data).await.with_context(|| - format!("Failed to add similar neuron vector for neuron {neuron_index} in model {model_name} to database.") - )?; + let similar_neurons = neuron_relatedness + .similar_neurons(neuron_index) + .with_context(|| { + format!( + "Failed to get similar neurons for neuron {neuron_index} in model \ + '{model_name}'." + ) + })?; + let data = similar_neurons.to_binary().with_context(|| { + format!( + "Failed to serialize similar neuron vector for neuron {neuron_index} in model \ + {model_name}." + ) + })?; + model_handle + .add_neuron_data( + data_type_handle, + neuron_index.layer, + neuron_index.neuron, + data, + ) + .await + .with_context(|| { + format!( + "Failed to add similar neuron vector for neuron {neuron_index} in model \ + {model_name} to database." + ) + })?; num_completed += 1; print!("\rAdding neuron similarities to database: {num_completed}/{num_neurons}",); } - println!("\rAdding neuron similarities to database: {num_completed}/{num_neurons} ",); + println!( + "\rAdding neuron similarities to database: {num_completed}/{num_neurons} ", + ); Ok(()) } diff --git a/src/data/retrieve/neuroscope.rs b/src/data/retrieve/neuroscope.rs index 01b2a2b..65ff31a 100644 --- a/src/data/retrieve/neuroscope.rs +++ b/src/data/retrieve/neuroscope.rs @@ -1,5 +1,10 @@ use std::{panic, sync::Arc, time::Duration}; +use anyhow::{bail, Context, Result}; +use reqwest::Client; +use scraper::{Html, Selector}; +use tokio::{sync::Semaphore, task::JoinSet}; + use crate::{ data::{ data_objects::{ @@ -12,11 +17,6 @@ use crate::{ Index, }; -use anyhow::{bail, Context, Result}; -use reqwest::Client; -use scraper::{Html, Selector}; -use tokio::{sync::Semaphore, task::JoinSet}; - const NEUROSCOPE_BASE_URL: &str = "https://neuroscope.io/"; const RETRY_LIMIT: u32 = 5; @@ -52,14 +52,30 @@ async fn scrape_neuron_page_to_database( NeuroscopeNeuronPage::from_binary(page_data)? } else { let page = scrape_neuron_page(model.name(), neuron_index).await?; - model.add_neuron_data( data_type, neuron_index.layer, neuron_index.neuron, page.to_binary()?).await.with_context(|| format!("Failed to write neuroscope page for neuron {neuron_index} in model '{model_name}' to database.", model_name = model.name()))?; + model + .add_neuron_data( + data_type, + neuron_index.layer, + neuron_index.neuron, + page.to_binary()?, + ) + .await + .with_context(|| { + format!( + "Failed to write neuroscope page for neuron {neuron_index} in model \ + '{model_name}' to database.", + model_name = model.name() + ) + })?; page }; let model_name = model.name(); - let first_text = page - .texts() - .first() - .with_context(|| format!("Failed to get first text from neuroscope page for neuron {neuron_index} in model '{model_name}'."))?; + let first_text = page.texts().first().with_context(|| { + format!( + "Failed to get first text from neuroscope page for neuron {neuron_index} in model \ + '{model_name}'." + ) + })?; let activation_range = first_text.max_activation() - first_text.min_activation(); Ok(activation_range) @@ -96,23 +112,28 @@ async fn scrape_layer_to_database( join_set.spawn(async move { let permit = semaphore.acquire_owned().await.unwrap(); let mut retries = 0; - let result = loop { - match scrape_neuron_page_to_database(&mut model, &data_type, neuron_index).await { - Ok(result) => break result, - Err(err) => { - if retries == RETRY_LIMIT { - log::error!("Failed to fetch neuroscope page for neuron {neuron_index} after {retries} retries. Error: {err:?}"); - return Err(err); - } + let result = loop { + match scrape_neuron_page_to_database(&mut model, &data_type, neuron_index).await + { + Ok(result) => break result, + Err(err) => { + if retries == RETRY_LIMIT { log::error!( - "Failed to fetch neuroscope page for neuron {neuron_index}. Retrying...", + "Failed to fetch neuroscope page for neuron {neuron_index} \ + after {retries} retries. Error: {err:?}" ); - log::error!("Error: {err:?}"); - retries += 1; - tokio::time::sleep(Duration::from_millis(5)).await; + return Err(err); } + log::error!( + "Failed to fetch neuroscope page for neuron {neuron_index}. \ + Retrying...", + ); + log::error!("Error: {err:?}"); + retries += 1; + tokio::time::sleep(Duration::from_millis(5)).await; } - }; + } + }; drop(permit); Ok::<_, anyhow::Error>((neuron_index, result)) }); diff --git a/src/index.rs b/src/index.rs index a2c1bdf..6f1b878 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,6 +1,7 @@ -use crate::data::{Metadata, NeuronIndex}; use anyhow::{anyhow, Result}; +use crate::data::{Metadata, NeuronIndex}; + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum Index { Model, @@ -31,13 +32,19 @@ impl Index { let layer_size = metadata.layer_size; match self { - Self::Layer(layer_index) | Self::Neuron(layer_index, _) if layer_index >= num_layers => Err(anyhow!( - "Layer index is {layer_index} but model '{model_name}' only has {num_layers} layers." - )), + Self::Layer(layer_index) | Self::Neuron(layer_index, _) + if layer_index >= num_layers => + { + Err(anyhow!( + "Layer index is {layer_index} but model '{model_name}' only has {num_layers} \ + layers." + )) + } Self::Neuron(_, neuron_index) if neuron_index >= layer_size => Err(anyhow!( - "Neuron index is {neuron_index} but model '{model_name}' only has {layer_size} neurons per layer." + "Neuron index is {neuron_index} but model '{model_name}' only has {layer_size} \ + neurons per layer." )), - _ => Ok(()) + _ => Ok(()), } } diff --git a/src/logging.rs b/src/logging.rs index 1de199e..6c69565 100644 --- a/src/logging.rs +++ b/src/logging.rs @@ -1,14 +1,15 @@ //! Sets up logging. -use env_logger::{fmt::Color, Logger, Target}; -use log::{Level, LevelFilter, Log}; -use multi_log::MultiLogger; use std::{ fs, io::Write, path::{Path, PathBuf}, }; +use env_logger::{fmt::Color, Logger, Target}; +use log::{Level, LevelFilter, Log}; +use multi_log::MultiLogger; + use crate::cli; fn log_file_path

(dir: P, index: u32) -> PathBuf diff --git a/src/pyo3/database.rs b/src/pyo3/database.rs index efe9da6..0284b06 100644 --- a/src/pyo3/database.rs +++ b/src/pyo3/database.rs @@ -2,16 +2,15 @@ use anyhow::Context; use pyo3::prelude::*; use tokio::runtime::Runtime; -use crate::{ - data::{data_types::DataType, Database}, - server::Service, -}; - use super::{ data_type::PyDataType, data_type_handle::PyDataTypeHandle, model_handle::PyModelHandle, model_metadata::PyModelMetadata, service_handle::PyServiceHandle, service_provider::PyServiceProvider, }; +use crate::{ + data::{data_types::DataType, Database}, + server::Service, +}; #[pyclass(name = "Database")] pub struct PyDatabase { @@ -77,8 +76,9 @@ impl PyDatabase { DataType::Json => {} data_type => { return Err(PyErr::new::(format!( - "Objects of data type {data_type:?} should be added with the appropriate method.", - ))) + "Objects of data type {data_type:?} should be added with the appropriate \ + method.", + ))) } } let result = Runtime::new() diff --git a/src/pyo3/model_handle.rs b/src/pyo3/model_handle.rs index 1e98dc6..e0a0209 100644 --- a/src/pyo3/model_handle.rs +++ b/src/pyo3/model_handle.rs @@ -2,12 +2,11 @@ use anyhow::{bail, Context}; use pyo3::prelude::*; use tokio::runtime::Runtime; -use crate::data::{retrieve, ModelHandle}; - use super::{ data_type_handle::PyDataTypeHandle, index::PyIndex, model_metadata::PyModelMetadata, service_handle::PyServiceHandle, }; +use crate::data::{retrieve, ModelHandle}; #[pyclass(name = "ModelHandle")] pub struct PyModelHandle { @@ -153,30 +152,33 @@ impl PyModelHandle { Runtime::new() .context("Failed to start async runtime to add JSON data.")? .block_on(async { - if !model.has_data_type(data_type).await.with_context(|| + if !model.has_data_type(data_type).await.with_context(|| { format!( - "Failed to check whether model '{model_name}' has data object '{data_type_name}'.", - model_name=model.name(), - data_type_name=data_type.name() + "Failed to check whether model '{model_name}' has data object \ + '{data_type_name}'.", + model_name = model.name(), + data_type_name = data_type.name() + ) + })? { + bail!( + "Cannot add JSON data to data object '{data_type_name}' for {index} in \ + model '{model_name}' because model does not have data object.", + data_type_name = data_type.name(), + index = index.index.error_string(), + model_name = model.name() ) - )? { - bail!("Cannot add JSON data to data object '{data_type_name}' for {index} in model '{model_name}' because model does not have data object.", - data_type_name=data_type.name(), - index=index.index.error_string(), - model_name=model.name()) } - retrieve::json::store_json_data( - model, - data_type, - index.into(), - json, - ) - .await.with_context(|| format!( - "Failed to add JSON data to '{data_type_name}' for {index} in model '{model_name}'.", - data_type_name=data_type.name(), - index=index.index.error_string(), - model_name=model.name() - )) + retrieve::json::store_json_data(model, data_type, index.into(), json) + .await + .with_context(|| { + format!( + "Failed to add JSON data to '{data_type_name}' for {index} in model \ + '{model_name}'.", + data_type_name = data_type.name(), + index = index.index.error_string(), + model_name = model.name() + ) + }) })?; Ok(()) } diff --git a/src/pyo3/service_provider.rs b/src/pyo3/service_provider.rs index 9f7d1c9..aa6af19 100644 --- a/src/pyo3/service_provider.rs +++ b/src/pyo3/service_provider.rs @@ -1,8 +1,7 @@ use pyo3::prelude::*; -use crate::server::ServiceProvider; - use super::data_type_handle::PyDataTypeHandle; +use crate::server::ServiceProvider; #[pyclass(name = "ServiceProvider")] #[derive(Clone)] diff --git a/src/server/response.rs b/src/server/response.rs index c49091d..d9f98fd 100644 --- a/src/server/response.rs +++ b/src/server/response.rs @@ -6,14 +6,13 @@ use anyhow::{anyhow, Result}; use reqwest::StatusCode; use serde_json::json; +use super::{RequestType, Service}; use crate::{ data::{data_objects::MetadataObject, Database, ModelHandle, ServiceHandle}, server::State, Index, }; -use super::{RequestType, Service}; - pub enum Body { Json(serde_json::Value), Binary(Vec), @@ -331,7 +330,10 @@ pub async fn model( Ok(request_type) => request_type, Err(error) => return Response::error(error, StatusCode::BAD_REQUEST), }; - log::debug!("Received {request_type_string} request for service '{service_name}' for model '{model_name}'."); + log::debug!( + "Received {request_type_string} request for service '{service_name}' for model \ + '{model_name}'." + ); response( state, query.deref(), @@ -370,7 +372,10 @@ pub async fn layer( Ok(request_type) => request_type, Err(error) => return Response::error(error, StatusCode::BAD_REQUEST), }; - log::debug!("Received {request_type_string} request for service '{service_name}' for layer {layer_index} in model '{model_name}'."); + log::debug!( + "Received {request_type_string} request for service '{service_name}' for layer \ + {layer_index} in model '{model_name}'." + ); response( state, query.deref(), @@ -411,7 +416,10 @@ pub async fn neuron( Ok(request_type) => request_type, Err(error) => return Response::error(error, StatusCode::BAD_REQUEST), }; - log::debug!("Received {request_type_string} request for service '{service_name}' for neuron l{layer_index}n{neuron_index} in model '{model_name}'."); + log::debug!( + "Received {request_type_string} request for service '{service_name}' for neuron \ + l{layer_index}n{neuron_index} in model '{model_name}'." + ); response( state, query.deref(), @@ -496,7 +504,10 @@ pub async fn all_neuron( query: web::Query, ) -> impl Responder { let (model_name, layer_index, neuron_index) = indices.into_inner(); - log::debug!("Received request for all services for neuron l{layer_index}n{neuron_index} in model '{model_name}'."); + log::debug!( + "Received request for all services for neuron l{layer_index}n{neuron_index} in model \ + '{model_name}'." + ); all_response( state, query, diff --git a/src/server/service.rs b/src/server/service.rs index 28c401b..19e2eb1 100644 --- a/src/server/service.rs +++ b/src/server/service.rs @@ -1,13 +1,12 @@ use anyhow::Result; use serde::{Deserialize, Serialize}; +use super::{response::Body, RequestType, ServiceProvider, State}; use crate::{ data::{DataTypeHandle, Database, ModelHandle}, Index, }; -use super::{response::Body, RequestType, ServiceProvider, State}; - #[derive(Clone, Serialize, Deserialize)] pub struct Service { pub name: String, diff --git a/src/server/service_providers/json.rs b/src/server/service_providers/json.rs index 0bfff71..e97f402 100644 --- a/src/server/service_providers/json.rs +++ b/src/server/service_providers/json.rs @@ -2,12 +2,12 @@ use anyhow::{bail, Context, Result}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use crate::data::data_types::Json as JsonData; -use crate::data::{DataTypeHandle, Database, ModelHandle}; -use crate::server::State; -use crate::Index; - use super::service_provider::ServiceProviderTrait; +use crate::{ + data::{data_types::Json as JsonData, DataTypeHandle, Database, ModelHandle}, + server::State, + Index, +}; #[derive(Clone, Serialize, Deserialize)] pub struct Json(String); @@ -45,9 +45,17 @@ async fn page( let model_name = model.name(); let json_object = data_type(state.database(), model, data_type_name).await?; let query = query.as_object().context("Query is not an object.")?; - let json = json_object.page(index).await.with_context(|| - format!("Failed to get json data object '{data_type_name}' for {index} of model '{model_name}'.", index = index.error_string()) - )?.value; + let json = json_object + .page(index) + .await + .with_context(|| { + format!( + "Failed to get json data object '{data_type_name}' for {index} of model \ + '{model_name}'.", + index = index.error_string() + ) + })? + .value; if query.is_empty() { Ok(json) } else if let Some(json_index) = query.get("get") { @@ -57,16 +65,32 @@ async fn page( if let Some(value) = json.get_mut(json_index) { Some(value) } else { - let int_index = json_index.parse::().with_context(|| format!("No field '{json_index}' exists and the index is not an integer."))?; + let int_index = json_index.parse::().with_context(|| { + format!("No field '{json_index}' exists and the index is not an integer.") + })?; json.get_mut(int_index) } } - serde_json::Value::Number(json_index) => json.get_mut(json_index.as_u64().context("Query 'get' field is not a u64.")? as usize), + serde_json::Value::Number(json_index) => json.get_mut( + json_index + .as_u64() + .context("Query 'get' field is not a u64.")? as usize, + ), _ => bail!("Query 'get' field is not a string or a number."), - }.with_context(|| format!("Failed to get json value '{json_index}' for {index} of model '{model_name}' and data object '{data_type_name}'.", index = index.error_string())) + } + .with_context(|| { + format!( + "Failed to get json value '{json_index}' for {index} of model '{model_name}' and \ + data object '{data_type_name}'.", + index = index.error_string() + ) + }) .map(serde_json::Value::take) } else { - bail!("Invalid query for json service. Query must be empty or contain a 'get' field. Query: {query:?}") + bail!( + "Invalid query for json service. Query must be empty or contain a 'get' field. Query: \ + {query:?}" + ) } } @@ -78,10 +102,12 @@ impl ServiceProviderTrait for Json { async fn required_data_types(&self, database: &Database) -> Result> { let Self(ref data_type_name) = self; - let data_type = database - .data_type(data_type_name) - .await? - .with_context(|| format!("No data object with name '{data_type_name}'. This should have been checked when the service was created."))?; + let data_type = database.data_type(data_type_name).await?.with_context(|| { + format!( + "No data object with name '{data_type_name}'. This should have been checked when \ + the service was created." + ) + })?; Ok(vec![data_type]) } diff --git a/src/server/service_providers/metadata.rs b/src/server/service_providers/metadata.rs index 171a105..b2aeb3f 100644 --- a/src/server/service_providers/metadata.rs +++ b/src/server/service_providers/metadata.rs @@ -2,13 +2,12 @@ use anyhow::Result; use async_trait::async_trait; use serde::{Deserialize, Serialize}; +use super::ServiceProviderTrait; use crate::{ data::{data_objects::MetadataObject, DataTypeHandle, Database, ModelHandle}, server::State, }; -use super::ServiceProviderTrait; - #[derive(Clone, Serialize, Deserialize)] pub struct Metadata; diff --git a/src/server/service_providers/neuron2graph.rs b/src/server/service_providers/neuron2graph.rs index 52df472..75f947a 100644 --- a/src/server/service_providers/neuron2graph.rs +++ b/src/server/service_providers/neuron2graph.rs @@ -1,8 +1,8 @@ use anyhow::{Context, Result}; use async_trait::async_trait; - use serde::{Deserialize, Serialize}; +use super::service_provider::{NoData, ServiceProviderTrait}; use crate::{ data::{ data_objects::Neuron2GraphData as Neuron2GraphDataObject, @@ -12,8 +12,6 @@ use crate::{ server::State, }; -use super::service_provider::{NoData, ServiceProviderTrait}; - #[derive(Clone, Serialize, Deserialize)] pub struct Neuron2Graph; @@ -62,11 +60,21 @@ impl ServiceProviderTrait for Neuron2Graph { let n2g_data_type = database .data_type(n2g_object_name) .await? - .with_context(|| format!("No data object with name '{n2g_object_name}'. This should have been checked when service was created."))?; + .with_context(|| { + format!( + "No data object with name '{n2g_object_name}'. This should have been checked \ + when service was created." + ) + })?; let neuron_store_data_type = database .data_type(neuron_store_object_name) .await? - .with_context(|| format!("No data object with name '{neuron_store_object_name}'. This should have been checked when service was created."))?; + .with_context(|| { + format!( + "No data object with name '{neuron_store_object_name}'. This should have been \ + checked when service was created." + ) + })?; Ok(vec![n2g_data_type, neuron_store_data_type]) } diff --git a/src/server/service_providers/neuron2graph_search.rs b/src/server/service_providers/neuron2graph_search.rs index 48f4428..78161e4 100644 --- a/src/server/service_providers/neuron2graph_search.rs +++ b/src/server/service_providers/neuron2graph_search.rs @@ -4,6 +4,7 @@ use anyhow::{Context, Result}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; +use super::service_provider::{NoData, ServiceProviderTrait}; use crate::{ data::{ data_types::NeuronStore as NeuronStoreObject, DataTypeHandle, Database, ModelHandle, @@ -12,8 +13,6 @@ use crate::{ server::State, }; -use super::service_provider::{NoData, ServiceProviderTrait}; - #[derive(Clone, Serialize, Deserialize)] pub struct Neuron2GraphSearch; @@ -24,9 +23,14 @@ impl ServiceProviderTrait for Neuron2GraphSearch { type NeuronPageObject = NoData; async fn required_data_types(&self, database: &Database) -> Result> { - database.data_type("neuron_store").await?.context( - "No data object named 'neuron_store' in database. This should have been checked when service was created." - ).map(|data_type| vec![data_type]) + database + .data_type("neuron_store") + .await? + .context( + "No data object named 'neuron_store' in database. This should have been checked \ + when service was created.", + ) + .map(|data_type| vec![data_type]) } async fn model_object( @@ -47,7 +51,8 @@ impl ServiceProviderTrait for Neuron2GraphSearch { .await .with_context(|| { format!( - "Model '{}' has no 'neuron_store' data object. This should have been checked earlier.", + "Model '{}' has no 'neuron_store' data object. This should have been checked \ + earlier.", model.name() ) })?; diff --git a/src/server/service_providers/neuron_explainer.rs b/src/server/service_providers/neuron_explainer.rs index 5c019d4..2567fcb 100644 --- a/src/server/service_providers/neuron_explainer.rs +++ b/src/server/service_providers/neuron_explainer.rs @@ -3,6 +3,7 @@ use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; +use super::service_provider::{NoData, ServiceProviderTrait}; use crate::{ data::{ data_objects::NeuronExplainerPage, data_types::NeuronExplainer as NeuronExplainerData, @@ -11,8 +12,6 @@ use crate::{ server::State, }; -use super::service_provider::{NoData, ServiceProviderTrait}; - #[derive(Clone, Serialize, Deserialize)] pub struct NeuronExplainer; @@ -39,10 +38,12 @@ impl ServiceProviderTrait for NeuronExplainer { async fn required_data_types(&self, database: &Database) -> Result> { let data_type_name = "neuron_explainer"; - let data_type = database - .data_type(data_type_name) - .await? - .with_context(|| format!("No data object with name '{data_type_name}'. This should have been checked when the service was created."))?; + let data_type = database.data_type(data_type_name).await?.with_context(|| { + format!( + "No data object with name '{data_type_name}'. This should have been checked when \ + the service was created." + ) + })?; Ok(vec![data_type]) } @@ -66,11 +67,18 @@ impl ServiceProviderTrait for NeuronExplainer { { page } else { - neuron_explainer::fetch_neuron(&Client::new(), neuron_explainer::model_url(model.name(), index)?).await.with_context(|| - format!("No neuron explainer page exists for neuron {index} in model '{model_name}' and fetching from source failed.", - model_name = model.name() - ) - )? + neuron_explainer::fetch_neuron( + &Client::new(), + neuron_explainer::model_url(model.name(), index)?, + ) + .await + .with_context(|| { + format!( + "No neuron explainer page exists for neuron {index} in model '{model_name}' \ + and fetching from source failed.", + model_name = model.name() + ) + })? }; Ok(page) } diff --git a/src/server/service_providers/neuroscope.rs b/src/server/service_providers/neuroscope.rs index 5af9d96..4e2ac78 100644 --- a/src/server/service_providers/neuroscope.rs +++ b/src/server/service_providers/neuroscope.rs @@ -3,13 +3,16 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use serde_json::json; -use crate::data::data_objects::{NeuroscopeLayerPage, NeuroscopeModelPage, NeuroscopeNeuronPage}; -use crate::data::data_types::Neuroscope as NeuroscopeData; -use crate::data::retrieve::neuroscope::scrape_neuron_page; -use crate::data::{DataTypeHandle, Database, ModelHandle, NeuronIndex}; -use crate::server::State; - use super::service_provider::ServiceProviderTrait; +use crate::{ + data::{ + data_objects::{NeuroscopeLayerPage, NeuroscopeModelPage, NeuroscopeNeuronPage}, + data_types::Neuroscope as NeuroscopeData, + retrieve::neuroscope::scrape_neuron_page, + DataTypeHandle, Database, ModelHandle, NeuronIndex, + }, + server::State, +}; #[derive(Clone, Serialize, Deserialize)] pub struct Neuroscope; @@ -37,10 +40,12 @@ impl ServiceProviderTrait for Neuroscope { async fn required_data_types(&self, database: &Database) -> Result> { let data_type_name = "neuroscope"; - let data_type = database - .data_type(data_type_name) - .await? - .with_context(|| format!("No data object with name '{data_type_name}'. This should have been checked when the service was created."))?; + let data_type = database.data_type(data_type_name).await?.with_context(|| { + format!( + "No data object with name '{data_type_name}'. This should have been checked when \ + the service was created." + ) + })?; Ok(vec![data_type]) } @@ -96,15 +101,29 @@ impl ServiceProviderTrait for Neuroscope { data_type(state, model) .await? .neuron_page(layer_index, neuron_index) - .await.with_context(|| format!( - "Failed to get neuroscope neuron page for neuron {neuron_index} in model '{model_name}'.", - neuron_index = NeuronIndex {layer: layer_index, neuron: neuron_index}, - model_name = model.name() - ))?.with_context(|| format!( - "Failed to get neuroscope neuron page for neuron {neuron_index} in model '{model_name}'.", - neuron_index = NeuronIndex {layer: layer_index, neuron: neuron_index}, - model_name = model.name() - )) + .await + .with_context(|| { + format!( + "Failed to get neuroscope neuron page for neuron {neuron_index} in model \ + '{model_name}'.", + neuron_index = NeuronIndex { + layer: layer_index, + neuron: neuron_index + }, + model_name = model.name() + ) + })? + .with_context(|| { + format!( + "Failed to get neuroscope neuron page for neuron {neuron_index} in model \ + '{model_name}'.", + neuron_index = NeuronIndex { + layer: layer_index, + neuron: neuron_index + }, + model_name = model.name() + ) + }) } async fn layer_json( @@ -138,11 +157,21 @@ impl ServiceProviderTrait for Neuroscope { { page } else { - scrape_neuron_page(model.name(), NeuronIndex{layer: layer_index, neuron: neuron_index}).await.with_context(|| - format!("No neuroscope page exists for neuron l{layer_index}n{neuron_index} in model '{model_name}' and fetching from source failed.", - model_name = model.name() - ) - )? + scrape_neuron_page( + model.name(), + NeuronIndex { + layer: layer_index, + neuron: neuron_index, + }, + ) + .await + .with_context(|| { + format!( + "No neuroscope page exists for neuron l{layer_index}n{neuron_index} in model \ + '{model_name}' and fetching from source failed.", + model_name = model.name() + ) + })? }; Ok(json!(page)) } diff --git a/src/server/start.rs b/src/server/start.rs index 19b7c8b..08cc7b3 100644 --- a/src/server/start.rs +++ b/src/server/start.rs @@ -3,7 +3,6 @@ use actix_web::{ web, App, HttpServer, }; use anyhow::{bail, Result}; - use utoipa_redoc::{Redoc, Servable}; use crate::{