Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(google): added gemini api as provider #79

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added bun.lockb
Binary file not shown.
4 changes: 3 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"@anthropic-ai/sdk": "^0.27.3",
"@commitlint/cli": "^19.5.0",
"@commitlint/config-conventional": "^19.5.0",
"@google/generative-ai": "^0.21.0",
"@ianvs/prettier-plugin-sort-imports": "^4.2.1",
"@release-it/conventional-changelog": "^8.0.2",
"@typescript-eslint/eslint-plugin": "^7.3.1",
Expand Down Expand Up @@ -60,5 +61,6 @@
}
],
"license": "MIT",
"author": "Arshad Yaseen <m@arshadyaseen.com> (https://arshadyaseen.com)"
"author": "Arshad Yaseen <m@arshadyaseen.com> (https://arshadyaseen.com)",
"dependencies": {}
}
8 changes: 6 additions & 2 deletions src/classes/copilot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ import {
DEFAULT_COPILOT_PROVIDER,
} from '../constants';
import {
createProviderEndpoint,
createProviderHeaders,
createRequestBody,
getCopilotProviderEndpoint,
parseProviderChatCompletion,
} from '../helpers/provider';
import {deprecated, report} from '../logger';
Expand Down Expand Up @@ -106,7 +106,11 @@ export class Copilot {
requestBody: ChatCompletionCreateParams;
headers: Record<string, string>;
} {
let endpoint = getCopilotProviderEndpoint(this.provider);
let endpoint = createProviderEndpoint(
this.model as CopilotModel,
this.apiKey,
this.provider,
);
let requestBody: ChatCompletionCreateParams;
let headers = createProviderHeaders(this.apiKey, this.provider);

Expand Down
12 changes: 11 additions & 1 deletion src/constants/copilot.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import {CopilotModel, CopilotProvider} from '../types';

export const COPILOT_PROVIDERS = ['groq', 'openai', 'anthropic'] as const;
export const COPILOT_PROVIDERS = [
'groq',
'openai',
'anthropic',
'google',
] as const;

export const COPILOT_MODEL_IDS: Record<CopilotModel, string> = {
'llama-3-70b': 'llama3-70b-8192',
Expand All @@ -11,6 +16,9 @@ export const COPILOT_MODEL_IDS: Record<CopilotModel, string> = {
'claude-3-5-haiku': 'claude-3-5-haiku-20241022',
'o1-preview': 'o1-preview',
'o1-mini': 'o1-mini',
'gemini-1.5-flash-8b': 'gemini-1.5-flash-8b',
'gemini-1.5-flash': 'gemini-1.5-flash',
'gemini-1.0-pro': 'gemini-1.0-pro',
} as const;

export const COPILOT_PROVIDER_MODEL_MAP: Record<
Expand All @@ -20,6 +28,7 @@ export const COPILOT_PROVIDER_MODEL_MAP: Record<
groq: ['llama-3-70b'],
openai: ['gpt-4o', 'gpt-4o-mini', 'o1-preview', 'o1-mini'],
anthropic: ['claude-3-5-sonnet', 'claude-3-haiku', 'claude-3-5-haiku'],
google: ['gemini-1.5-flash-8b', 'gemini-1.0-pro', 'gemini-1.5-flash'],
} as const;

export const DEFAULT_COPILOT_PROVIDER: CopilotProvider = 'anthropic' as const;
Expand All @@ -29,6 +38,7 @@ export const COPILOT_PROVIDER_ENDPOINT_MAP: Record<CopilotProvider, string> = {
groq: 'https://api.groq.com/openai/v1/chat/completions',
openai: 'https://api.openai.com/v1/chat/completions',
anthropic: 'https://api.anthropic.com/v1/messages',
google: 'https://generativelanguage.googleapis.com/v1beta/models',
} as const;

export const DEFAULT_COPILOT_TEMPERATURE = 0.1 as const;
7 changes: 7 additions & 0 deletions src/constants/provider/google.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import {GoogleModel} from '../../types';

export const MAX_TOKENS_BY_ANTHROPIC_MODEL: Record<GoogleModel, number> = {
'gemini-1.5-flash': 8192,
'gemini-1.5-flash-8b': 8192,
'gemini-1.0-pro': 8192,
} as const;
67 changes: 61 additions & 6 deletions src/helpers/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import {
} from '../types';

const openaiHandler: ProviderHandler<'openai'> = {
/* Model and API key unused in this handler */
createEndpoint: () => COPILOT_PROVIDER_ENDPOINT_MAP.openai,

createRequestBody: (model, prompt) => {
const isO1Model = model === 'o1-preview' || model === 'o1-mini';
const messages = isO1Model
Expand Down Expand Up @@ -45,6 +48,9 @@ const openaiHandler: ProviderHandler<'openai'> = {
};

const groqHandler: ProviderHandler<'groq'> = {
/* Model and API key unused in this handler */
createEndpoint: () => COPILOT_PROVIDER_ENDPOINT_MAP.groq,

createRequestBody: (model, prompt) => ({
model: getModelId(model),
temperature: DEFAULT_COPILOT_TEMPERATURE,
Expand All @@ -68,6 +74,9 @@ const groqHandler: ProviderHandler<'groq'> = {
};

const anthropicHandler: ProviderHandler<'anthropic'> = {
/* Model and API key unused in this handler */
createEndpoint: () => COPILOT_PROVIDER_ENDPOINT_MAP.anthropic,

createRequestBody: (model, prompt) => ({
model: getModelId(model),
temperature: DEFAULT_COPILOT_TEMPERATURE,
Expand Down Expand Up @@ -102,13 +111,65 @@ const anthropicHandler: ProviderHandler<'anthropic'> = {
},
};

const googleHandler: ProviderHandler<'google'> = {
createEndpoint: (model, apiKey) =>
`${COPILOT_PROVIDER_ENDPOINT_MAP.google}/${model}:generateContent?key=${apiKey}`,

createRequestBody: (model, prompt) => ({
model: getModelId(model),
system_instruction: {
parts: {text: prompt.system},
},
contents: [
{
parts: {text: prompt.user},
},
],
}),

/* No API key is needed for this provider in the headers */
createHeaders: () => ({
'Content-Type': 'application/json',
}),

parseCompletion: completion => {
if (
!completion.candidates?.length ||
!completion.candidates[0]?.content ||
!completion.candidates[0].content?.parts?.length
) {
return null;
}

const content = completion.candidates[0].content;

return 'text' in content.parts[0] &&
typeof content.parts[0].text === 'string'
? content.parts[0].text
: null;
},
};

const providerHandlers: Record<
CopilotProvider,
ProviderHandler<CopilotProvider>
> = {
openai: openaiHandler,
groq: groqHandler,
anthropic: anthropicHandler,
google: googleHandler,
};

/**
* Creates an endpoint for different copilot providers.
*/
export const createProviderEndpoint = (
model: CopilotModel,
apiKey: string,
provider: CopilotProvider,
): string => {
const handler = providerHandlers[provider];
return handler.createEndpoint(model, apiKey);
};

/**
Expand Down Expand Up @@ -150,12 +211,6 @@ export const parseProviderChatCompletion = (
*/
const getModelId = (model: CopilotModel): string => COPILOT_MODEL_IDS[model];

/**
* Gets the copilot endpoint for a given provider.
*/
export const getCopilotProviderEndpoint = (provider: CopilotProvider): string =>
COPILOT_PROVIDER_ENDPOINT_MAP[provider];

/**
* Computes the maximum number of tokens for Anthropic models.
*/
Expand Down
37 changes: 30 additions & 7 deletions src/types/copilot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ import type {
MessageCreateParams as AnthropicChatCompletionCreateParamsBase,
Message as AnthropicChatCompletionType,
} from '@anthropic-ai/sdk/resources';
import {
ModelParams as GoogleChatCompletionCreateParamsBase,
GenerateContentResponse as GoogleChatCompletionType,
} from '@google/generative-ai';
import type {
ChatCompletionCreateParamsBase as GroqChatCompletionCreateParamsBase,
ChatCompletion as GroqChatCompletionType,
Expand All @@ -17,29 +21,42 @@ export type AnthropicModel =
| 'claude-3-5-sonnet'
| 'claude-3-5-haiku'
| 'claude-3-haiku';
export type GoogleModel =
| 'gemini-1.5-flash'
| 'gemini-1.5-flash-8b'
| 'gemini-1.0-pro';

export type CopilotModel = OpenAIModel | GroqModel | AnthropicModel;
export type CopilotModel =
| OpenAIModel
| GroqModel
| AnthropicModel
| GoogleModel;

export type PickCopilotModel<T extends CopilotProvider> = T extends 'openai'
? OpenAIModel
: T extends 'groq'
? GroqModel
: T extends 'anthropic'
? AnthropicModel
: never;
: T extends 'google'
? GoogleModel
: never;

export type CopilotProvider = 'openai' | 'groq' | 'anthropic';
export type CopilotProvider = 'openai' | 'groq' | 'anthropic' | 'google';

export type ChatCompletionCreateParams =
| OpenAIChatCompletionCreateParamsBase
| GroqChatCompletionCreateParamsBase
| AnthropicChatCompletionCreateParamsBase;
| AnthropicChatCompletionCreateParamsBase
| GoogleChatCompletionCreateParams;

export type OpenAIChatCompletionCreateParams =
OpenAIChatCompletionCreateParamsBase;
export type GroqChatCompletionCreateParams = GroqChatCompletionCreateParamsBase;
export type AnthropicChatCompletionCreateParams =
AnthropicChatCompletionCreateParamsBase;
export type GoogleChatCompletionCreateParams =
GoogleChatCompletionCreateParamsBase;

export type PickChatCompletionCreateParams<T extends CopilotProvider> =
T extends 'openai'
Expand All @@ -48,24 +65,30 @@ export type PickChatCompletionCreateParams<T extends CopilotProvider> =
? GroqChatCompletionCreateParamsBase
: T extends 'anthropic'
? AnthropicChatCompletionCreateParamsBase
: never;
: T extends 'google'
? GoogleChatCompletionCreateParamsBase
: never;

export type ChatCompletion =
| OpenAIChatCompletion
| GroqChatCompletion
| AnthropicChatCompletion;
| AnthropicChatCompletion
| GoogleChatCompletionType;

export type OpenAIChatCompletion = OpenAIChatCompletionType;
export type GroqChatCompletion = GroqChatCompletionType;
export type AnthropicChatCompletion = AnthropicChatCompletionType;
export type GoogleChatCompletion = GoogleChatCompletionType;

export type PickChatCompletion<T extends CopilotProvider> = T extends 'openai'
? OpenAIChatCompletion
: T extends 'groq'
? GroqChatCompletion
: T extends 'anthropic'
? AnthropicChatCompletion
: never;
: T extends 'google'
? GoogleChatCompletionType
: never;

export type PromptData = {
system: string;
Expand Down
3 changes: 1 addition & 2 deletions src/types/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@ import {
} from './copilot';

export interface ProviderHandler<T extends CopilotProvider> {
createEndpoint(model: PickCopilotModel<T>, apiKey: string): string;
createRequestBody(
model: PickCopilotModel<T>,
prompt: PromptData,
): PickChatCompletionCreateParams<T>;

createHeaders(apiKey: string): Record<string, string>;

parseCompletion(completion: PickChatCompletion<T>): string | null;
}
3 changes: 2 additions & 1 deletion src/utils/http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ const request = async <
});

if (!response.ok) {
const data = '\n' + (JSON.stringify(await response.json(), null, 2) || '');
throw new Error(
`${response.statusText || options.fallbackError || 'Network error'}`,
`${response.statusText || options.fallbackError || 'Network error'}${data}`,
);
}

Expand Down
Binary file added tests/ui/bun.lockb
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/ui/src/app/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ export default function Home() {
return (
<Editor
language="javascript"
onMount={(editor, monaco) => {
onMount={(editor: StandaloneCodeEditor, monaco: Monaco) => {
setMonaco(monaco);
setEditor(editor);
}}
Expand Down