diff --git a/DESCRIPTION b/DESCRIPTION index 461fc07f..4cdcc37b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -42,7 +42,7 @@ Imports: utils, yaml Suggests: - AzureRMR, + AzureGraph, future, grDevices, knitr, diff --git a/NAMESPACE b/NAMESPACE index 46fd2c12..9eaaa51c 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -41,6 +41,7 @@ export(create_completion_huggingface) export(get_available_endpoints) export(get_available_models) export(get_ide_theme_info) +export(gptstudio_cache_directory) export(gptstudio_chat) export(gptstudio_chat_in_source_addin) export(gptstudio_comment_code) diff --git a/R/cache.R b/R/cache.R new file mode 100644 index 00000000..8d367584 --- /dev/null +++ b/R/cache.R @@ -0,0 +1,5 @@ +#' a function that determines the appropriate directory to cache a token +#' @export +gptstudio_cache_directory <- function() { + tools::R_user_dir(package = "gptstudio") +} diff --git a/R/service-azure_openai.R b/R/service-azure_openai.R index a24dce46..035d9258 100644 --- a/R/service-azure_openai.R +++ b/R/service-azure_openai.R @@ -108,33 +108,54 @@ query_api_azure_openai <- } retrieve_azure_token <- function() { - rlang::check_installed("AzureRMR") - - token <- tryCatch( - { - AzureRMR::get_azure_login( - tenant = Sys.getenv("AZURE_OPENAI_TENANT_ID"), - app = Sys.getenv("AZURE_OPENAI_CLIENT_ID"), - scopes = ".default" - ) - }, - error = function(e) NULL - ) - if (is.null(token)) { - token <- AzureRMR::create_azure_login( - tenant = Sys.getenv("AZURE_OPENAI_TENANT_ID"), - app = Sys.getenv("AZURE_OPENAI_CLIENT_ID"), - password = Sys.getenv("AZURE_OPENAI_CLIENT_SECRET"), - host = "https://cognitiveservices.azure.com/", - scopes = ".default" - ) + token <- retrieve_azure_token_object() |> suppressMessages() + + invisible(token$credentials$access_token) +} + + +retrieve_azure_token_object <- function() { + rlang::check_installed("AzureGraph") + + ## Set this so that get_graph_login properly caches + azure_data_env <- Sys.getenv("R_AZURE_DATA_DIR") + + Sys.setenv("R_AZURE_DATA_DIR" = gptstudio_cache_directory()) + + login <- try(AzureGraph::get_graph_login(tenant = Sys.getenv("AZURE_OPENAI_TENANT_ID"), + app = Sys.getenv("AZURE_OPENAI_CLIENT_ID"), + scopes = NULL, + refresh = FALSE), + silent = TRUE) |> + suppressMessages() + + if (inherits(login, "try-error")) { + + if (!dir.exists(gptstudio_cache_directory())) { + dir.create(gptstudio_cache_directory()) |> + suppressWarnings() + } + + + login <- AzureGraph::create_graph_login(tenant = Sys.getenv("AZURE_OPENAI_TENANT_ID"), + app = Sys.getenv("AZURE_OPENAI_CLIENT_ID"), + host = Sys.getenv("AZURE_OPENAI_SCOPE"), + scopes = NULL, + auth_type = "client_credentials", + password = Sys.getenv("AZURE_OPENAI_CLIENT_SECRET")) |> + suppressMessages() } - invisible(token$token$credentials$access_token) + ## Set this so that get_graph_login properly caches + Sys.setenv("R_AZURE_DATA_DIR" = azure_data_env) + + invisible(login$token) } + + stream_azure_openai <- function(messages = list(list(role = "user", content = "hi there")), element_callback = cat) { body <- list( @@ -155,6 +176,5 @@ stream_azure_openai <- function(messages = list(list(role = "user", content = "h }, round = "line" ) - invisible(response) } diff --git a/man/gptstudio_cache_directory.Rd b/man/gptstudio_cache_directory.Rd new file mode 100644 index 00000000..6d8b4ccf --- /dev/null +++ b/man/gptstudio_cache_directory.Rd @@ -0,0 +1,11 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/cache.R +\name{gptstudio_cache_directory} +\alias{gptstudio_cache_directory} +\title{a function that determines the appropriate directory to cache a token} +\usage{ +gptstudio_cache_directory() +} +\description{ +a function that determines the appropriate directory to cache a token +} diff --git a/tests/testthat/test-service-azure_openai.R b/tests/testthat/test-service-azure_openai.R index c78192f9..cd5dd038 100644 --- a/tests/testthat/test-service-azure_openai.R +++ b/tests/testthat/test-service-azure_openai.R @@ -153,12 +153,14 @@ test_that("query_api_azure_openai handles error response", { # Test token retrieval -------------------------------------------------------- test_that("retrieve_azure_token successfully gets existing token", { + skip_on_ci() + local_mocked_bindings( - get_azure_login = function(...) { - list(token = list(credentials = list(access_token = "existing_token"))) + get_graph_login = function(...) { + list(credentials = list(access_token = "existing_token")) }, - create_azure_login = function(...) stop("Should not be called"), - .package = "AzureRMR" + create_graph_login = function(...) stop("Should not be called"), + .package = "AzureGraph" ) token <- retrieve_azure_token() @@ -166,13 +168,15 @@ test_that("retrieve_azure_token successfully gets existing token", { expect_equal(token, "existing_token") }) -test_that("retrieve_azure_token creates new token when get_azure_login fails", { +test_that("retrieve_azure_token creates new token when get_graph_login fails", { + skip_on_ci() + local_mocked_bindings( - get_azure_login = function(...) stop("Error"), - create_azure_login = function(...) { - list(token = list(credentials = list(access_token = "new_token"))) + get_graph_login = function(...) stop("Error"), + create_graph_login = function(...) { + list(credentials = list(access_token = "new_token")) }, - .package = "AzureRMR" + .package = "AzureGraph" ) token <- retrieve_azure_token() @@ -180,41 +184,50 @@ test_that("retrieve_azure_token creates new token when get_azure_login fails", { expect_equal(token, "new_token") }) + test_that("retrieve_azure_token uses correct environment variables", { - mock_get_azure_login <- function(tenant, app, scopes) { + skip_on_ci() + + mock_get_graph_login <- function(tenant, app, scopes, refresh) { expect_equal(tenant, "test_tenant") expect_equal(app, "test_client") - expect_equal(scopes, ".default") + expect_equal(scopes, NULL) + expect_equal(refresh, FALSE) stop("Error") } - mock_create_azure_login <- function(tenant, app, password, host, scopes) { + mock_create_graph_login <- function(tenant, app, host, scopes, auth_type, password) { expect_equal(tenant, "test_tenant") expect_equal(app, "test_client") + expect_equal(host, "https://cognitiveservices.azure.com/.default") + expect_equal(scopes, NULL) + expect_equal(auth_type, "client_credentials") expect_equal(password, "test_secret") - expect_equal(host, "https://cognitiveservices.azure.com/") - expect_equal(scopes, ".default") - list(token = list(credentials = list(access_token = "new_token"))) + list(credentials = list(access_token = "new_token")) } local_mocked_bindings( - get_azure_login = mock_get_azure_login, - create_azure_login = mock_create_azure_login, - .package = "AzureRMR" + get_graph_login = mock_get_graph_login, + create_graph_login = mock_create_graph_login, + .package = "AzureGraph" ) withr::local_envvar( AZURE_OPENAI_TENANT_ID = "test_tenant", AZURE_OPENAI_CLIENT_ID = "test_client", - AZURE_OPENAI_CLIENT_SECRET = "test_secret" + AZURE_OPENAI_CLIENT_SECRET = "test_secret", + AZURE_OPENAI_SCOPE = "https://cognitiveservices.azure.com/.default" ) expect_no_error(retrieve_azure_token()) }) -test_that("retrieve_azure_token checks for AzureRMR installation", { + + + +test_that("retrieve_azure_token checks for AzureGraph installation", { mock_check_installed <- function(pkg) { - expect_equal(pkg, "AzureRMR") + expect_equal(pkg, "AzureGraph") } local_mocked_bindings(