Skip to content

Commit

Permalink
feat!: allow users to specify scopes (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxOhn authored Oct 19, 2024
1 parent 1f959f1 commit dad87b8
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 70 deletions.
26 changes: 18 additions & 8 deletions src/client/builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use super::{token::AuthorizationBuilder, Authorization, AuthorizationKind, Osu, OsuRef, Token};
use super::{
token::AuthorizationBuilder, Authorization, AuthorizationKind, Osu, OsuRef, Scopes, Token,
};
use crate::{error::OsuError, OsuResult};

use hyper::client::Builder;
Expand Down Expand Up @@ -63,11 +65,12 @@ impl OsuBuilder {
let auth_kind = match self.auth {
Some(AuthorizationBuilder::Kind(kind)) => kind,
#[cfg(feature = "local_oauth")]
Some(AuthorizationBuilder::LocalOauth { redirect_uri }) => {
AuthorizationBuilder::perform_local_oauth(redirect_uri, client_id)
.await
.map(AuthorizationKind::User)?
}
Some(AuthorizationBuilder::LocalOauth {
redirect_uri,
scopes,
}) => AuthorizationBuilder::perform_local_oauth(redirect_uri, client_id, scopes)
.await
.map(AuthorizationKind::User)?,
None => AuthorizationKind::default(),
};

Expand All @@ -91,7 +94,7 @@ impl OsuBuilder {

let inner = Arc::new(OsuRef {
client_id,
client_secret,
client_secret: client_secret.into_boxed_str(),
http,
ratelimiter,
timeout: self.timeout,
Expand Down Expand Up @@ -159,9 +162,14 @@ impl OsuBuilder {
/// [`with_authorization`]: OsuBuilder::with_authorization
#[cfg(feature = "local_oauth")]
#[cfg_attr(docsrs, doc(cfg(feature = "local_oauth")))]
pub fn with_local_authorization(mut self, redirect_uri: impl Into<String>) -> Self {
pub fn with_local_authorization(
mut self,
redirect_uri: impl Into<String>,
scopes: Scopes,
) -> Self {
self.auth = Some(AuthorizationBuilder::LocalOauth {
redirect_uri: redirect_uri.into(),
scopes,
});

self
Expand All @@ -180,10 +188,12 @@ impl OsuBuilder {
mut self,
code: impl Into<String>,
redirect_uri: impl Into<String>,
scopes: Scopes,
) -> Self {
let authorization = Authorization {
code: code.into(),
redirect_uri: redirect_uri.into(),
scopes,
};

self.auth = Some(AuthorizationBuilder::Kind(AuthorizationKind::User(
Expand Down
27 changes: 15 additions & 12 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
mod builder;
mod scopes;
mod token;

use bytes::Bytes;
use token::{Authorization, AuthorizationKind, Token, TokenResponse};

pub use builder::OsuBuilder;
pub use token::Scope;
pub use scopes::Scopes;

#[allow(clippy::wildcard_imports)]
use crate::{error::OsuError, model::GameMode, request::*, OsuResult};
Expand Down Expand Up @@ -600,7 +601,7 @@ impl Drop for Osu {

pub(crate) struct OsuRef {
client_id: u64,
client_secret: String,
client_secret: Box<str>,
http: HyperClient<HttpsConnector<HttpConnector>, BodyBytes>,
timeout: Duration,
ratelimiter: LeakyBucket,
Expand All @@ -626,24 +627,26 @@ impl OsuRef {
body.push_int("client_id", self.client_id);
body.push_str("client_secret", &self.client_secret);

match &self.auth_kind {
AuthorizationKind::Client(scope) => {
match self.auth_kind {
AuthorizationKind::Client => {
body.push_str("grant_type", "client_credentials");
body.push_str("scope", scope.as_str());
let mut scopes = String::new();
Scopes::Public.format(&mut scopes, ' ');
body.push_str("scope", &scopes);
}
AuthorizationKind::User(auth) => match &self.token.read().await.refresh {
Some(refresh) => {
AuthorizationKind::User(ref auth) => {
if let Some(ref refresh) = self.token.read().await.refresh {
body.push_str("grant_type", "refresh_token");
body.push_str("refresh_token", refresh);
}
None => {
} else {
body.push_str("grant_type", "authorization_code");
body.push_str("redirect_uri", &auth.redirect_uri);
body.push_str("code", &auth.code);
// FIXME: let users decide which scopes to use?
body.push_str("scope", "identify public");
let mut scopes = String::new();
auth.scopes.format(&mut scopes, ' ');
body.push_str("scope", &scopes);
}
},
}
};

let bytes = BodyBytes::from(body);
Expand Down
99 changes: 99 additions & 0 deletions src/client/scopes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use std::ops::{BitOr, BitOrAssign};

/// Scopes bitflags for an [`Osu`] client.
///
/// To specify multiple scopes, create a union using the `|` operator.
///
/// See <https://osu.ppy.sh/docs/index.html#scopes>.
///
/// [`Osu`]: crate::Osu
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct Scopes(u16);

macro_rules! define_scopes {
( $(
#[doc = $doc:literal]
$scope:ident: $shift:literal, $str:literal;
)* ) => {
define_scopes! {@ $(
#[doc = $doc]
$scope: 1 << $shift, $str;
)* }
};
(@ $(
#[doc = $doc:literal]
$scope:ident: $bit:expr, $str:literal;
)* ) => {
$(
#[allow(non_upper_case_globals)]
impl Scopes {
#[doc = $doc]
pub const $scope: Self = Self($bit);
}
)*

impl Scopes {
const fn contains(self, bit: u16) -> bool {
(self.0 & bit) > 0
}

pub(crate) fn format(self, s: &mut String, separator: char) {
let mut first_scope = true;

$(
if self.contains($bit) {
if !first_scope {
s.push(separator);
}

s.push_str($str);

#[allow(unused_assignments)]
{
first_scope = false;
}
}
)*
}
}
};
}

define_scopes! {
/// Allows reading chat messages on a user's behalf.
ChatRead: 0, "chat.read";
/// Allows sending chat messages on a user's behalf.
ChatWrite: 1, "chat.write";
/// Allows joining and leaving chat channels on a user's behalf.
ChatWriteManage: 2, "chat.write_manage";
/// Allows acting as the owner of a client.
Delegate: 3, "delegate";
/// Allows creating and editing forum posts on a user's behalf.
ForumWrite: 4, "forum.write";
/// Allows reading of the user's friend list.
FriendsRead: 5, "friends.read";
/// Allows reading of the public profile of the user.
Identify: 6, "identify";
/// Allows reading of publicly available data on behalf of the user.
Public: 7, "public";
}

impl Default for Scopes {
fn default() -> Self {
Self::Public
}
}

impl BitOr for Scopes {
type Output = Self;

fn bitor(self, rhs: Self) -> Self::Output {
Self(self.0.bitor(rhs.0))
}
}

impl BitOrAssign for Scopes {
fn bitor_assign(&mut self, rhs: Self) {
self.0.bitor_assign(rhs.0);
}
}
63 changes: 14 additions & 49 deletions src/client/token.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use super::OsuRef;
use super::{OsuRef, Scopes};

use serde::Deserialize;
use std::fmt::{Display, Formatter, Result as FmtResult};
use std::{error::Error, sync::Arc, time::Duration};
use tokio::{
sync::oneshot::{self, Receiver},
Expand Down Expand Up @@ -115,6 +114,7 @@ pub(super) enum AuthorizationBuilder {
#[cfg(feature = "local_oauth")]
LocalOauth {
redirect_uri: String,
scopes: Scopes,
},
}

Expand All @@ -123,9 +123,9 @@ impl AuthorizationBuilder {
pub(super) async fn perform_local_oauth(
redirect_uri: String,
client_id: u64,
scopes: Scopes,
) -> Result<Authorization, crate::error::OAuthError> {
use std::{
fmt::Write,
io::{Error as IoError, ErrorKind},
str::from_utf8 as str_from_utf8,
};
Expand Down Expand Up @@ -159,17 +159,9 @@ impl AuthorizationBuilder {
&response_type=code",
);

let mut scopes = [Scope::Identify, Scope::Public].iter();

if let Some(scope) = scopes.next() {
let _ = write!(url, "&scopes=%22{scope}");

for scope in scopes {
let _ = write!(url, "+{scope}");
}

url.push_str("%22");
}
url.push_str("&scopes=%22");
scopes.format(&mut url, '+');
url.push_str("%22");

println!("Authorize yourself through the following url:\n{url}");
info!("Awaiting manual authorization...");
Expand Down Expand Up @@ -230,24 +222,29 @@ You may close this tab

respond(&mut stream).await.map_err(OAuthError::Write)?;

Ok(Authorization { code, redirect_uri })
Ok(Authorization {
code,
redirect_uri,
scopes,
})
}
}

pub(super) enum AuthorizationKind {
User(Authorization),
Client(Scope),
Client,
}

impl Default for AuthorizationKind {
fn default() -> Self {
Self::Client(Scope::Public)
Self::Client
}
}

pub(super) struct Authorization {
pub code: String,
pub redirect_uri: String,
pub scopes: Scopes,
}

#[derive(Deserialize)]
Expand All @@ -258,35 +255,3 @@ pub(super) struct TokenResponse {
pub refresh_token: Option<String>,
pub token_type: String,
}

#[derive(Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum Scope {
ChatWrite,
Delegate,
ForumWrite,
FriendsRead,
Identify,
Lazer,
Public,
}

impl Scope {
pub const fn as_str(self) -> &'static str {
match self {
Scope::ChatWrite => "chat.write",
Scope::Delegate => "delegate",
Scope::ForumWrite => "forum.write",
Scope::FriendsRead => "friends.read",
Scope::Identify => "identify",
Scope::Lazer => "lazer",
Scope::Public => "public",
}
}
}

impl Display for Scope {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
f.write_str(self.as_str())
}
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ pub type OsuResult<T> = Result<T, error::OsuError>;
/// All types except requesting, stuffed into one module
pub mod prelude {
pub use crate::{
client::Scope,
client::Scopes,
error::OsuError,
model::{
beatmap::*,
Expand Down

0 comments on commit dad87b8

Please sign in to comment.