Skip to content

Commit

Permalink
some error refactoring, add is_writable
Browse files Browse the repository at this point in the history
  • Loading branch information
arilotter committed Nov 1, 2024
1 parent 338864d commit 7764fac
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 54 deletions.
25 changes: 21 additions & 4 deletions examples/upload.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,38 @@
use std::time::Instant;

use hf_hub::{api::tokio::ApiBuilder, Repo};
use hf_hub::{
api::tokio::{ApiBuilder, ApiError},
Repo,
};
use rand::Rng;

const ONE_MB: usize = 1024 * 1024;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init();
let token =
std::env::var("HF_TOKEN").map_err(|_| format!("HF_TOKEN environment variable not set"))?;
let token = std::env::var("HF_TOKEN")
.map_err(|_| "HF_TOKEN environment variable not set".to_string())?;
let hf_repo = std::env::var("HF_REPO")
.map_err(|_| format!("HF_REPO environment variable not set, e.g. apyh/gronk"))?;
.map_err(|_| "HF_REPO environment variable not set, e.g. apyh/gronk".to_string())?;

let api = ApiBuilder::new().with_token(Some(token)).build()?;
let repo = Repo::model(hf_repo);
let api_repo = api.repo(repo);

let exists = api_repo.exists().await;
if !exists {
return Err(ApiError::GatedRepoError("repo does not exist".to_string()).into());
} else {
println!("repo exists!");
}

let is_writable = api_repo.is_writable().await;
if !is_writable {
return Err(ApiError::GatedRepoError("repo is not writable".to_string()).into());
} else {
println!("repo is writable!");
}
let files = [
(
format!("im a tiny file {:?}", Instant::now())
Expand Down
95 changes: 82 additions & 13 deletions src/api/tokio/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use super::RepoInfo;
use crate::{Cache, Repo, RepoType};
use http::StatusCode;
use rand::Rng;
use regex::Regex;
use reqwest::{
header::{
HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue, ToStrError, AUTHORIZATION,
Expand Down Expand Up @@ -33,13 +35,14 @@ const NAME: &str = env!("CARGO_PKG_NAME");
/// which can be useful for debugging and error reporting when HTTP requests fail.
#[derive(Debug)]
pub struct ReqwestErrorWithBody {
/// URL of the request that failed; printed on the first line of the `Display` output.
url: String,
/// The underlying `reqwest` error describing the failed request.
error: ReqwestError,
/// Text of the response body, or the error encountered while trying to read it.
body: Result<String, ReqwestError>,
}

impl Display for ReqwestErrorWithBody {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "Request error:",)?;
writeln!(f, "Request error: {}", self.url)?;
writeln!(f, "{}", self.error)?;
match &self.body {
Ok(body) => {
Expand All @@ -65,7 +68,7 @@ impl std::error::Error for ReqwestErrorWithBody {}
/// # Examples
///
/// ```
/// use hf_hub::api::tokio::ReqwestBadResponse;
/// use hf_hub::api::tokio::HfBadResponse;
///
/// async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let response = reqwest::get("https://api.example.com/data").await?;
Expand All @@ -81,29 +84,75 @@ impl std::error::Error for ReqwestErrorWithBody {}
/// # Error Handling
///
/// - If the response status is successful (2xx), returns `Ok(Response)`
/// - If the response status indicates an error (4xx, 5xx), returns `Err(ReqwestErrorWithBody)`
/// - If the response status indicates an error (4xx, 5xx), returns `Err(ApiError)`
/// containing both the original error and the response body text
pub trait ReqwestBadResponse {
pub trait HfBadResponse {
/// Checks if the response status code indicates an error, and if so, captures the response body
/// along with the error details.
///
/// Returns a Future that resolves to:
/// - `Ok(Response)` if the status code is successful
/// - `Err(ReqwestErrorWithBody)` if the status code indicates an error
fn maybe_err(self) -> impl Future<Output = Result<Self, ReqwestErrorWithBody>>
/// - `Err(ApiError)` if the status code indicates an error
fn maybe_hf_err(self) -> impl Future<Output = Result<Self, ApiError>>
where
Self: Sized;
}

impl ReqwestBadResponse for reqwest::Response {
async fn maybe_err(self) -> Result<Self, ReqwestErrorWithBody>
// Pattern recognising URLs that address a repository on the Hub.
//
// The URL must start with `https://<host>` and continue with either
// `/api/(models|datasets|spaces)/<repo_id>` or `/<repo_id>/resolve/<revision>...`.
// The pattern is anchored at the start only, so any trailing path is accepted.
// It is written in verbose mode (`(?x)`), so the whitespace and `#` comments
// inside the literal are ignored by the regex engine.
//
// `maybe_hf_err` consults it to decide whether a 401 response carrying the
// `RepoNotFound` error code should be reported as `RepositoryNotFoundError`.
lazy_static::lazy_static! {
static ref REPO_API_REGEX: Regex = Regex::new(
r#"(?x)
# staging or production endpoint
^https://[^/]+
(
# on /api/repo_type/repo_id
/api/(models|datasets|spaces)/(.+)
|
# or /repo_id/resolve/revision/...
/(.+)/resolve/(.+)
)
"#,
).unwrap();
}

impl HfBadResponse for reqwest::Response {
async fn maybe_hf_err(self) -> Result<Self, ApiError>
where
Self: Sized,
{
let error = self.error_for_status_ref();
if let Err(error) = error {
let body = self.text().await;
Err(ReqwestErrorWithBody { body, error })
let hf_error_code = self
.headers()
.get("X-Error-Code")
.and_then(|v| v.to_str().ok());
let hf_error_message = self
.headers()
.get("X-Error-Message")
.and_then(|v| v.to_str().ok());
let url = self.url().to_string();
Err(match (hf_error_code, hf_error_message) {
(Some("RevisionNotFound"), _) => ApiError::RevisionNotFoundError(url),
(Some("EntryNotFound"), _) => ApiError::EntryNotFoundError(url),
(Some("GatedRepo"), _) => ApiError::GatedRepoError(url),
(_, Some("Access to this resource is disabled.")) => {
ApiError::DisabledRepoError(url)
}
// 401 is misleading as it is returned for:
// - private and gated repos if user is not authenticated
// - missing repos
// => for now, we process them as `RepoNotFound` anyway.
// See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9
(Some("RepoNotFound"), _)
if self.status() == StatusCode::UNAUTHORIZED
&& REPO_API_REGEX.is_match(&url) =>
{
ApiError::RepositoryNotFoundError(url)
}
(_, _) => {
let body = self.text().await;
ApiError::RequestErrorWithBody(ReqwestErrorWithBody { url, body, error })
}
})
} else {
Ok(self)
}
Expand All @@ -117,7 +166,7 @@ pub enum ApiError {
#[error("Header {0} is missing")]
MissingHeader(HeaderName),

/// The header exists, but the value is not conform to what the Api expects.
/// The header exists, but the value does not conform to what the Api expects.
#[error("Header {0} is invalid")]
InvalidHeader(HeaderName),

Expand Down Expand Up @@ -158,8 +207,28 @@ pub enum ApiError {
AcquireError(#[from] AcquireError),

/// Bad data from the API
#[error("Invalid Response: {0:?}")]
#[error("Invalid Response: {0}")]
InvalidResponse(String),

/// Repo exists, but the revision / oid doesn't exist.
#[error("Revision Not Found for url: {0}")]
RevisionNotFoundError(String),

/// Repo and revision exist, but the requested entry (file) doesn't exist.
#[error("Entry Not Found for url: {0}")]
EntryNotFoundError(String),

/// Repo is gated
#[error("Cannot access gated repo for url: {0}")]
GatedRepoError(String),

/// Repo is disabled
#[error("Cannot access repo - access to resource is disabled for url: {0}")]
DisabledRepoError(String),

/// Repo does not exist for the caller (could be private)
#[error("Repository Not Found for url: {0}")]
RepositoryNotFoundError(String),
}

/// Helper to create [`Api`] with all the options.
Expand Down Expand Up @@ -228,7 +297,7 @@ impl ApiBuilder {
self
}

/// Sets the token to be used in the API
/// Sets the token to be used in the API
pub fn with_token(mut self, token: Option<String>) -> Self {
self.token = token;
self
Expand Down
69 changes: 50 additions & 19 deletions src/api/tokio/repo_info.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::RepoType;

use super::{Api, ApiError, ApiRepo, ReqwestBadResponse};
use super::{Api, ApiError, ApiRepo, HfBadResponse};

#[derive(Debug)]
pub enum RepoInfo {
Expand Down Expand Up @@ -35,6 +35,53 @@ impl ApiRepo {
RepoType::Space => todo!(),
}
}

/// Reports whether this repository can be found on the Hugging Face Hub.
///
/// The answer comes from a `repo_info` lookup:
/// - a successful lookup means the repository exists;
/// - a gated-repo error also counts as existing (it is there, just not accessible);
/// - every other failure, including "repository not found", yields `false`.
pub async fn exists(&self) -> bool {
    matches!(
        self.repo_info().await,
        Ok(_) | Err(ApiError::GatedRepoError(_))
    )
}

/// Checks if this repository exists and is writable on the Hugging Face Hub.
pub async fn is_writable(&self) -> bool {
if !self.exists().await {
return false;
}
let mut headers = HeaderMap::new();
headers.insert("Content-Type", "application/x-ndjson".parse().unwrap());

let url = format!(
"{}/api/{}s/{}/commit/{}",
self.api.endpoint,
self.repo.repo_type.to_string(),
self.repo.url(),
self.repo.revision
);

let res: Result<StatusCode, ApiError> = (async {
Ok(self
.api
.client
.post(&url)
.headers(headers)
.send()
.await
.map_err(ApiError::from)?
.status())
})
.await;
if let Ok(status) = res {
if status == StatusCode::FORBIDDEN {
return false;
}
}
true
}
}

impl Api {
Expand All @@ -49,23 +96,6 @@ impl Api {
/// revision (`str`, *optional*):
/// The revision of the model repository from which to get the
/// information.
/// timeout (`float`, *optional*):
/// Whether to set a timeout for the request to the Hub.
/// securityStatus (`bool`, *optional*):
/// Whether to retrieve the security status from the model
/// repository as well.
/// files_metadata (`bool`, *optional*):
/// Whether or not to retrieve metadata for files in the repository
/// (size, LFS metadata, etc). Defaults to `False`.
/// expand (`List[ExpandModelProperty_T]`, *optional*):
/// List properties to return in the response. When used, only the properties in the list will be returned.
/// This parameter cannot be used if `securityStatus` or `files_metadata` are passed.
/// Possible values are `"author"`, `"baseModels"`, `"cardData"`, `"childrenModelCount"`, `"config"`, `"createdAt"`, `"disabled"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"gguf"`, `"inference"`, `"lastModified"`, `"library_name"`, `"likes"`, `"mask_token"`, `"model-index"`, `"pipeline_tag"`, `"private"`, `"safetensors"`, `"sha"`, `"siblings"`, `"spaces"`, `"tags"`, `"transformersInfo"`, `"trendingScore"` and `"widgetData"`.
/// token (Union[bool, str, None], optional):
/// A valid user access token (string). Defaults to the locally saved
/// token, which is the recommended method for authentication (see
/// https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
/// To disable authentication, pass `False`.
async fn model_info(
&self,
repo_id: &str,
Expand All @@ -88,7 +118,7 @@ impl Api {
.get(url)
.send()
.await?
.maybe_err()
.maybe_hf_err()
.await?
.json()
.await?;
Expand All @@ -97,6 +127,7 @@ impl Api {
}
}

use http::{HeaderMap, StatusCode};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

Expand Down
19 changes: 6 additions & 13 deletions src/api/tokio/upload/commit_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use tokio::fs::{read_to_string, File};
use tokio::io::{self, AsyncRead, AsyncReadExt, BufReader};

use crate::api::tokio::upload::lfs::lfs_upload;
use crate::api::tokio::{ApiError, ApiRepo, ReqwestBadResponse};
use crate::api::tokio::{ApiError, ApiRepo, HfBadResponse};

use super::commit_info::{CommitInfo, InvalidHfIdError};

Expand Down Expand Up @@ -361,13 +361,7 @@ impl ApiRepo {

// Pre-upload LFS files
let additions = self
.preupload_lfs_files(
additions,
Some(create_pr),
Some(num_threads),
None,
self.api.endpoint.clone(),
)
.preupload_lfs_files(additions, Some(create_pr), Some(num_threads), None)
.await
.map_err(CommitError::Api)?;

Expand Down Expand Up @@ -461,7 +455,7 @@ impl ApiRepo {
.send()
.await
.map_err(ApiError::from)?
.maybe_err()
.maybe_hf_err()
.await
.map_err(ApiError::from)?;

Expand Down Expand Up @@ -501,7 +495,6 @@ impl ApiRepo {
create_pr: Option<bool>,
num_threads: Option<usize>,
gitignore_content: Option<String>,
endpoint: String,
) -> Result<Vec<CommitOperationAdd>, ApiError> {
// Set default values
let create_pr = create_pr.unwrap_or(false);
Expand Down Expand Up @@ -551,7 +544,7 @@ impl ApiRepo {

// Upload LFS files
let uploaded_lfs_files = self
.upload_lfs_files(new_lfs_additions_to_upload, num_threads, endpoint)
.upload_lfs_files(new_lfs_additions_to_upload, num_threads)
.await?;
Ok(small_files.into_iter().chain(uploaded_lfs_files).collect())
}
Expand Down Expand Up @@ -593,7 +586,7 @@ impl ApiRepo {
.json(&payload)
.send()
.await?
.maybe_err()
.maybe_hf_err()
.await?
.json()
.await?;
Expand Down Expand Up @@ -637,7 +630,6 @@ impl ApiRepo {
&self,
additions: Vec<CommitOperationAdd>,
num_threads: usize,
endpoint: String,
) -> Result<Vec<CommitOperationAdd>, ApiError> {
// Step 1: Retrieve upload instructions from LFS batch endpoint
let mut batch_objects = Vec::new();
Expand Down Expand Up @@ -700,6 +692,7 @@ impl ApiRepo {
let s3_client = reqwest::Client::new();

// Step 3: Upload files concurrently
let endpoint = self.api.endpoint.clone();
let upload_futures: Vec<_> = filtered_actions
.into_iter()
.map(|batch_action| {
Expand Down
Loading

0 comments on commit 7764fac

Please sign in to comment.