Skip to content

Commit

Permalink
* #1 Support for non GUI users
Browse files Browse the repository at this point in the history
  • Loading branch information
z-Wind committed Jul 15, 2024
1 parent 7ffe4af commit 3e35da1
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 28 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ tokio = { version = "1", features = [
reqwest = { version = "0.11", features = ["blocking", "json"] }
dirs = "5.0"
url = "2.5"
http = "1.1"
axum = { version = "0.7", features = ["macros"] }
axum-server = { version = "0.6", features = ["tls-rustls"] }
thiserror = "1.0"
Expand Down
52 changes: 51 additions & 1 deletion src/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,35 @@ impl TokenChecker {
redirect_url: String,
certs_dir: PathBuf,
) -> Result<Self, Error> {
let auth = Authorizer::new(client_id, secret, redirect_url, certs_dir);
let auth = Authorizer::new(
client_id,
secret,
redirect_url,
auth::AuthProcess::Auto { certs_dir },
);
let token = match Token::load(path.clone()) {
Ok(token) => token,
Err(_) => auth.save(path.clone()).await?,
};

let checker = Self {
path,
authorizer: auth,
token: Mutex::new(token),
};

checker.check_or_update().await?;

Ok(checker)
}

pub async fn new_with_auth_manually(
path: PathBuf,
client_id: String,
secret: String,
redirect_url: String,
) -> Result<Self, Error> {
let auth = Authorizer::new(client_id, secret, redirect_url, auth::AuthProcess::Manual);
let token = match Token::load(path.clone()) {
Ok(token) => token,
Err(_) => auth.save(path.clone()).await?,
Expand Down Expand Up @@ -170,6 +198,28 @@ mod tests {
.unwrap();
}

#[tokio::test]
#[ignore = "Testing manually for browser verification. Should be --nocapture"]
async fn test_token_checker_new_with_auth_manually() {
let path = dirs::home_dir()
.expect("home dir")
.join(".credentials")
.join("Schwab-rust.json");
#[allow(clippy::option_env_unwrap)]
let client_id = option_env!("SCHWAB_API_KEY").expect("There should be SCHWAB API KEY");
#[allow(clippy::option_env_unwrap)]
let secret = option_env!("SCHWAB_SECRET").expect("There should be SCHWAB SECRET");

TokenChecker::new_with_auth_manually(
path,
client_id.to_string(),
secret.to_string(),
"https://127.0.0.1:8080".to_string(),
)
.await
.unwrap();
}

#[test]
fn test_save_token() {
let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
Expand Down
135 changes: 117 additions & 18 deletions src/token/auth.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use axum::extract::Query;
use http::uri::Uri;
use oauth2::basic::BasicClient;
use oauth2::reqwest::async_http_client;
use oauth2::{
Expand All @@ -6,6 +8,7 @@ use oauth2::{
TokenUrl,
};
use oauth2::{ClientSecret, Scope};
use serde::Deserialize;
use std::path::PathBuf;
use url::Url;

Expand All @@ -15,18 +18,30 @@ use crate::token::Token;

type RequestTokenError = BasicRequestTokenError<oauth2::reqwest::Error<reqwest::Error>>;

#[derive(Debug)]
pub(super) enum AuthProcess {
Auto { certs_dir: PathBuf },
Manual,
}

#[derive(Debug, Deserialize)]
pub(super) struct AuthRequest {
pub(super) code: String,
pub(super) state: String,
}

#[derive(Debug)]
pub(super) struct Authorizer {
client: BasicClient,
certs_dir: PathBuf,
process: AuthProcess,
}

impl Authorizer {
pub(super) fn new(
app_key: String,
secret: String,
redirect_url: String,
certs_dir: PathBuf,
process: AuthProcess,
) -> Self {
let app_key = ClientId::new(app_key);
let secret = ClientSecret::new(secret);
Expand All @@ -38,21 +53,25 @@ impl Authorizer {

let client = BasicClient::new(app_key, Some(secret), auth_url, Some(token_url))
.set_redirect_uri(redirect_url);
Authorizer { client, certs_dir }
Authorizer { client, process }
}

async fn authorize(&self) -> Result<Token, RequestTokenError> {
let (auth_url, csrf_token) = self.auth_code_url();

match open::that(auth_url.as_ref()) {
Ok(()) => println!("Opened '{auth_url}' successfully."),
Err(err) => {
print!("An error occurred when opening '{auth_url}': {err}");
println!("Please Open this URL in your browser manually\n{auth_url}",);
}
}

let auth_code = Self::auth_code(csrf_token, self.certs_dir.clone()).await;
let auth_code = match &self.process {
AuthProcess::Auto { certs_dir } => match open::that(auth_url.as_ref()) {
Ok(()) => {
println!("Opened '{auth_url}' successfully.");
Self::get_auth_code_with_local_server(csrf_token, certs_dir.clone()).await
}
Err(err) => {
print!("An error occurred when auto opening: {err}");
Self::get_auth_code_manually(&csrf_token, &auth_url)
}
},
AuthProcess::Manual => Self::get_auth_code_manually(&csrf_token, &auth_url),
};

let token_result = self.refresh_token(auth_code).await?;
// dbg!(&token_result);
Expand Down Expand Up @@ -84,12 +103,60 @@ impl Authorizer {
(auth_url, csrf_token)
}

async fn auth_code(csrf_state: CsrfToken, certs_dir: PathBuf) -> AuthorizationCode {
async fn get_auth_code_with_local_server(
csrf_state: CsrfToken,
certs_dir: PathBuf,
) -> AuthorizationCode {
let code = local_server::local_server(csrf_state, certs_dir).await;

AuthorizationCode::new(code)
}

fn get_auth_code_manually(csrf: &CsrfToken, auth_url: &Url) -> AuthorizationCode {
println!(
r#"
**************************************************************
This is the manual login and token creation flow for schwab_api.
Please follow these instructions exactly:
1. Open the following link by copy-pasting it into the browser
of your choice:
{auth_url}
2. Log in with your account credentials. You may be asked to
perform two-factor authentication using text messaging or
another method, as well as whether to trust the browser.
3. When asked whether to allow your app access to your account,
select "Allow".
4. Your browser should be redirected to your callback URI. Copy
the ENTIRE address, paste it into the following prompt, and press
Enter/Return.
**************************************************************
Redirect URL>"#
);

let mut input = String::new();
std::io::stdin()
.read_line(&mut input)
.unwrap_or_else(|err| panic!("error: {err}"));

let uri: Uri = input.trim().parse().expect("right uri");
Self::uri_to_auth_code(&uri, csrf)
}

fn uri_to_auth_code(uri: &Uri, csrf: &CsrfToken) -> AuthorizationCode {
let Query(query): Query<AuthRequest> = Query::try_from_uri(uri).expect("right format");
assert!(&query.state == csrf.secret(), "CSRF check error");

AuthorizationCode::new(query.code)
}

async fn refresh_token(
&self,
auth_code: AuthorizationCode,
Expand Down Expand Up @@ -141,12 +208,32 @@ mod tests {

#[tokio::test]
#[ignore = "Testing manually for browser verification. Should be --nocapture"]
async fn test_auth() {
async fn test_auth_auto() {
let auth = Authorizer::new(
client_id_static().to_string(),
secret_static().to_string(),
REDIRECT_URL.to_string(),
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/certs"),
AuthProcess::Auto {
certs_dir: PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/certs"),
},
);

let token = auth.authorize().await.unwrap();
dbg!(&token);

// test refresh access token
let access_token = auth.access_token(&token.refresh).await.unwrap();
dbg!(&access_token);
}

#[tokio::test]
#[ignore = "Testing manually for browser verification. Should be --nocapture"]
async fn test_auth_manually() {
let auth = Authorizer::new(
client_id_static().to_string(),
secret_static().to_string(),
REDIRECT_URL.to_string(),
AuthProcess::Manual,
);

let token = auth.authorize().await.unwrap();
Expand All @@ -165,7 +252,9 @@ mod tests {
CLIENTID.to_string(),
SECRET.to_string(),
REDIRECT_URL.to_string(),
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/certs"),
AuthProcess::Auto {
certs_dir: PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/certs"),
},
);

let (auth_url, csrf_token) = auth.auth_code_url();
Expand Down Expand Up @@ -201,8 +290,8 @@ mod tests {

#[tokio::test]
#[ignore = "If the test is performed manually on Linux, it may fail for HTTPS."]
async fn test_get_auth_code() {
let auth_code = tokio::spawn(Authorizer::auth_code(
async fn test_get_auth_code_with_local_server() {
let auth_code = tokio::spawn(Authorizer::get_auth_code_with_local_server(
CsrfToken::new("CSRF".to_string()),
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/certs"),
));
Expand All @@ -224,4 +313,14 @@ mod tests {
assert_eq!(auth_code.await.unwrap().secret(), "code");
assert_eq!(body, "Schwab returned the following code:\ncode\nYou can now safely close this browser window.");
}

#[test]
fn test_uri_to_auth_code() {
let csrf = CsrfToken::new("CSRF".to_string());
let uri: Uri = format!("https://127.0.0.1:8080/?state={}&code=code", csrf.secret())
.parse()
.unwrap();
let auth_code = Authorizer::uri_to_auth_code(&uri, &csrf);
assert_eq!(auth_code.secret(), "code");
}
}
10 changes: 1 addition & 9 deletions src/token/local_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use axum::{
};
use axum_server::tls_rustls::RustlsConfig;
use oauth2::CsrfToken;
use serde::Deserialize;
use std::net::SocketAddr;
use std::path::PathBuf;

Expand Down Expand Up @@ -39,15 +38,8 @@ struct AppState {
tx: async_channel::Sender<String>,
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct AuthRequest {
code: String,
state: String,
}

async fn get_code(
Query(query): Query<AuthRequest>,
Query(query): Query<super::auth::AuthRequest>,
State(csrf): State<CsrfToken>,
State(tx): State<async_channel::Sender<String>>,
) -> impl IntoResponse {
Expand Down

0 comments on commit 3e35da1

Please sign in to comment.