diff --git a/src/client/builder.rs b/src/client/builder.rs index 6d2abc7..035ed10 100644 --- a/src/client/builder.rs +++ b/src/client/builder.rs @@ -1,4 +1,6 @@ -use super::{token::AuthorizationBuilder, Authorization, AuthorizationKind, Osu, OsuRef, Token}; +use super::{ + token::AuthorizationBuilder, Authorization, AuthorizationKind, Osu, OsuRef, Scopes, Token, +}; use crate::{error::OsuError, OsuResult}; use hyper::client::Builder; @@ -63,11 +65,12 @@ impl OsuBuilder { let auth_kind = match self.auth { Some(AuthorizationBuilder::Kind(kind)) => kind, #[cfg(feature = "local_oauth")] - Some(AuthorizationBuilder::LocalOauth { redirect_uri }) => { - AuthorizationBuilder::perform_local_oauth(redirect_uri, client_id) - .await - .map(AuthorizationKind::User)? - } + Some(AuthorizationBuilder::LocalOauth { + redirect_uri, + scopes, + }) => AuthorizationBuilder::perform_local_oauth(redirect_uri, client_id, scopes) + .await + .map(AuthorizationKind::User)?, None => AuthorizationKind::default(), }; @@ -91,7 +94,7 @@ impl OsuBuilder { let inner = Arc::new(OsuRef { client_id, - client_secret, + client_secret: client_secret.into_boxed_str(), http, ratelimiter, timeout: self.timeout, @@ -159,9 +162,14 @@ impl OsuBuilder { /// [`with_authorization`]: OsuBuilder::with_authorization #[cfg(feature = "local_oauth")] #[cfg_attr(docsrs, doc(cfg(feature = "local_oauth")))] - pub fn with_local_authorization(mut self, redirect_uri: impl Into) -> Self { + pub fn with_local_authorization( + mut self, + redirect_uri: impl Into, + scopes: Scopes, + ) -> Self { self.auth = Some(AuthorizationBuilder::LocalOauth { redirect_uri: redirect_uri.into(), + scopes, }); self @@ -180,10 +188,12 @@ impl OsuBuilder { mut self, code: impl Into, redirect_uri: impl Into, + scopes: Scopes, ) -> Self { let authorization = Authorization { code: code.into(), redirect_uri: redirect_uri.into(), + scopes, }; self.auth = Some(AuthorizationBuilder::Kind(AuthorizationKind::User( diff --git a/src/client/mod.rs b/src/client/mod.rs index 7147b18..ea77052 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,11 +1,12 @@ mod builder; +mod scopes; mod token; use bytes::Bytes; use token::{Authorization, AuthorizationKind, Token, TokenResponse}; pub use builder::OsuBuilder; -pub use token::Scope; +pub use scopes::Scopes; #[allow(clippy::wildcard_imports)] use crate::{error::OsuError, model::GameMode, request::*, OsuResult}; @@ -600,7 +601,7 @@ impl Drop for Osu { pub(crate) struct OsuRef { client_id: u64, - client_secret: String, + client_secret: Box, http: HyperClient, BodyBytes>, timeout: Duration, ratelimiter: LeakyBucket, @@ -626,24 +627,26 @@ impl OsuRef { body.push_int("client_id", self.client_id); body.push_str("client_secret", &self.client_secret); - match &self.auth_kind { - AuthorizationKind::Client(scope) => { + match self.auth_kind { + AuthorizationKind::Client => { body.push_str("grant_type", "client_credentials"); - body.push_str("scope", scope.as_str()); + let mut scopes = String::new(); + Scopes::Public.format(&mut scopes, ' '); + body.push_str("scope", &scopes); } - AuthorizationKind::User(auth) => match &self.token.read().await.refresh { - Some(refresh) => { + AuthorizationKind::User(ref auth) => { + if let Some(ref refresh) = self.token.read().await.refresh { body.push_str("grant_type", "refresh_token"); body.push_str("refresh_token", refresh); - } - None => { + } else { body.push_str("grant_type", "authorization_code"); body.push_str("redirect_uri", &auth.redirect_uri); body.push_str("code", &auth.code); - // FIXME: let users decide which scopes to use? - body.push_str("scope", "identify public"); + let mut scopes = String::new(); + auth.scopes.format(&mut scopes, ' '); + body.push_str("scope", &scopes); } - }, + } }; let bytes = BodyBytes::from(body); diff --git a/src/client/scopes.rs b/src/client/scopes.rs new file mode 100644 index 0000000..a94367e --- /dev/null +++ b/src/client/scopes.rs @@ -0,0 +1,99 @@ +use std::ops::{BitOr, BitOrAssign}; + +/// Scopes bitflags for an [`Osu`] client. +/// +/// To specify multiple scopes, create a union using the `|` operator. +/// +/// See . +/// +/// [`Osu`]: crate::Osu +#[derive(Copy, Clone, PartialEq, Eq)] +pub struct Scopes(u16); + +macro_rules! define_scopes { + ( $( + #[doc = $doc:literal] + $scope:ident: $shift:literal, $str:literal; + )* ) => { + define_scopes! {@ $( + #[doc = $doc] + $scope: 1 << $shift, $str; + )* } + }; + (@ $( + #[doc = $doc:literal] + $scope:ident: $bit:expr, $str:literal; + )* ) => { + $( + #[allow(non_upper_case_globals)] + impl Scopes { + #[doc = $doc] + pub const $scope: Self = Self($bit); + } + )* + + impl Scopes { + const fn contains(self, bit: u16) -> bool { + (self.0 & bit) > 0 + } + + pub(crate) fn format(self, s: &mut String, separator: char) { + let mut first_scope = true; + + $( + if self.contains($bit) { + if !first_scope { + s.push(separator); + } + + s.push_str($str); + + #[allow(unused_assignments)] + { + first_scope = false; + } + } + )* + } + } + }; +} + +define_scopes! { + /// Allows reading chat messages on a user's behalf. + ChatRead: 0, "chat.read"; + /// Allows sending chat messages on a user's behalf. + ChatWrite: 1, "chat.write"; + /// Allows joining and leaving chat channels on a user's behalf. + ChatWriteManage: 2, "chat.write_manage"; + /// Allows acting as the owner of a client. + Delegate: 3, "delegate"; + /// Allows creating and editing forum posts on a user's behalf. + ForumWrite: 4, "forum.write"; + /// Allows reading of the user's friend list. + FriendsRead: 5, "friends.read"; + /// Allows reading of the public profile of the user. + Identify: 6, "identify"; + /// Allows reading of publicly available data on behalf of the user. + Public: 7, "public"; +} + +impl Default for Scopes { + fn default() -> Self { + Self::Public + } +} + +impl BitOr for Scopes { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self::Output { + Self(self.0.bitor(rhs.0)) + } +} + +impl BitOrAssign for Scopes { + fn bitor_assign(&mut self, rhs: Self) { + self.0.bitor_assign(rhs.0); + } +} diff --git a/src/client/token.rs b/src/client/token.rs index f108355..d1c4334 100644 --- a/src/client/token.rs +++ b/src/client/token.rs @@ -1,7 +1,6 @@ -use super::OsuRef; +use super::{OsuRef, Scopes}; use serde::Deserialize; -use std::fmt::{Display, Formatter, Result as FmtResult}; use std::{error::Error, sync::Arc, time::Duration}; use tokio::{ sync::oneshot::{self, Receiver}, @@ -115,6 +114,7 @@ pub(super) enum AuthorizationBuilder { #[cfg(feature = "local_oauth")] LocalOauth { redirect_uri: String, + scopes: Scopes, }, } @@ -123,9 +123,9 @@ impl AuthorizationBuilder { pub(super) async fn perform_local_oauth( redirect_uri: String, client_id: u64, + scopes: Scopes, ) -> Result { use std::{ - fmt::Write, io::{Error as IoError, ErrorKind}, str::from_utf8 as str_from_utf8, }; @@ -159,17 +159,9 @@ impl AuthorizationBuilder { &response_type=code", ); - let mut scopes = [Scope::Identify, Scope::Public].iter(); - - if let Some(scope) = scopes.next() { - let _ = write!(url, "&scopes=%22{scope}"); - - for scope in scopes { - let _ = write!(url, "+{scope}"); - } - - url.push_str("%22"); - } + url.push_str("&scopes=%22"); + scopes.format(&mut url, '+'); + url.push_str("%22"); println!("Authorize yourself through the following url:\n{url}"); info!("Awaiting manual authorization..."); @@ -230,24 +222,29 @@ You may close this tab respond(&mut stream).await.map_err(OAuthError::Write)?; - Ok(Authorization { code, redirect_uri }) + Ok(Authorization { + code, + redirect_uri, + scopes, + }) } } pub(super) enum AuthorizationKind { User(Authorization), - Client(Scope), + Client, } impl Default for AuthorizationKind { fn default() -> Self { - Self::Client(Scope::Public) + Self::Client } } pub(super) struct Authorization { pub code: String, pub redirect_uri: String, + pub scopes: Scopes, } #[derive(Deserialize)] @@ -258,35 +255,3 @@ pub(super) struct TokenResponse { pub refresh_token: Option, pub token_type: String, } - -#[derive(Copy, Clone, Eq, PartialEq)] -#[non_exhaustive] -pub enum Scope { - ChatWrite, - Delegate, - ForumWrite, - FriendsRead, - Identify, - Lazer, - Public, -} - -impl Scope { - pub const fn as_str(self) -> &'static str { - match self { - Scope::ChatWrite => "chat.write", - Scope::Delegate => "delegate", - Scope::ForumWrite => "forum.write", - Scope::FriendsRead => "friends.read", - Scope::Identify => "identify", - Scope::Lazer => "lazer", - Scope::Public => "public", - } - } -} - -impl Display for Scope { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - f.write_str(self.as_str()) - } -} diff --git a/src/lib.rs b/src/lib.rs index 6e5c18b..c463799 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -168,7 +168,7 @@ pub type OsuResult = Result; /// All types except requesting, stuffed into one module pub mod prelude { pub use crate::{ - client::Scope, + client::Scopes, error::OsuError, model::{ beatmap::*,