diff --git a/package.json b/package.json index 8d7baa10..09b10084 100644 --- a/package.json +++ b/package.json @@ -29,6 +29,7 @@ "electron-window-state": "^5.0.3" }, "devDependencies": { + "@anthropic-ai/sdk": "^0.24.3", "@electron-toolkit/eslint-config-prettier": "^2.0.0", "@electron-toolkit/eslint-config-ts": "^1.0.1", "@electron-toolkit/tsconfig": "^1.0.1", diff --git a/src/renderer/src/assets/images/models/claude.png b/src/renderer/src/assets/images/models/claude.png new file mode 100644 index 00000000..e6213255 Binary files /dev/null and b/src/renderer/src/assets/images/models/claude.png differ diff --git a/src/renderer/src/assets/images/providers/anthropic.jpeg b/src/renderer/src/assets/images/providers/anthropic.jpeg new file mode 100644 index 00000000..6cb2e6ab Binary files /dev/null and b/src/renderer/src/assets/images/providers/anthropic.jpeg differ diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index 33d1503e..e9ade2e4 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -386,5 +386,35 @@ export const SYSTEM_MODELS: Record = { group: 'Gemma', enabled: false } + ], + anthropic: [ + { + id: 'claude-3-5-sonnet-20240620', + provider: 'anthropic', + name: 'Claude 3.5 Sonnet', + group: 'Claude 3.5', + enabled: true + }, + { + id: 'claude-3-opus-20240229', + provider: 'anthropic', + name: 'Claude 3 Opus', + group: 'Claude 3', + enabled: true + }, + { + id: 'claude-3-sonnet-20240229', + provider: 'anthropic', + name: 'Claude 3 Sonnet', + group: 'Claude 3', + enabled: true + }, + { + id: 'claude-3-haiku-20240307', + provider: 'anthropic', + name: 'Claude 3 Haiku', + group: 'Claude 3', + enabled: true + } ] } diff --git a/src/renderer/src/config/provider.ts b/src/renderer/src/config/provider.ts index c271ddcd..8d2bdf28 100644 --- a/src/renderer/src/config/provider.ts +++ b/src/renderer/src/config/provider.ts @@ -9,6 +9,7 @@ import MoonshotProviderLogo from 
'@renderer/assets/images/providers/moonshot.jpe import OpenRouterProviderLogo from '@renderer/assets/images/providers/openrouter.png' import BaichuanProviderLogo from '@renderer/assets/images/providers/baichuan.png' import DashScopeProviderLogo from '@renderer/assets/images/providers/dashscope.png' +import AnthropicProviderLogo from '@renderer/assets/images/providers/anthropic.jpeg' import ChatGPTModelLogo from '@renderer/assets/images/models/chatgpt.jpeg' import ChatGLMModelLogo from '@renderer/assets/images/models/chatglm.jpeg' import DeepSeekModelLogo from '@renderer/assets/images/models/deepseek.png' @@ -20,6 +21,7 @@ import MixtralModelLogo from '@renderer/assets/images/models/mixtral.jpeg' import MoonshotModelLogo from '@renderer/assets/images/providers/moonshot.jpeg' import MicrosoftModelLogo from '@renderer/assets/images/models/microsoft.png' import BaichuanModelLogo from '@renderer/assets/images/models/baichuan.png' +import ClaudeModelLogo from '@renderer/assets/images/models/claude.png' export function getProviderLogo(providerId: string) { switch (providerId) { @@ -45,6 +47,8 @@ export function getProviderLogo(providerId: string) { return BaichuanProviderLogo case 'dashscope': return DashScopeProviderLogo + case 'anthropic': + return AnthropicProviderLogo default: return undefined } @@ -63,7 +67,8 @@ export function getModelLogo(modelId: string) { mistral: MixtralModelLogo, moonshot: MoonshotModelLogo, phi: MicrosoftModelLogo, - baichuan: BaichuanModelLogo + baichuan: BaichuanModelLogo, + claude: ClaudeModelLogo } for (const key in logoMap) { @@ -162,5 +167,13 @@ export const PROVIDER_CONFIG = { docs: 'https://github.com/ollama/ollama/tree/main/docs', models: 'https://ollama.com/library' } + }, + anthropic: { + websites: { + official: 'https://anthropic.com/', + apiKey: 'https://console.anthropic.com/settings/keys', + docs: 'https://docs.anthropic.com/en/docs', + models: 'https://docs.anthropic.com/en/docs/about-claude/models' + } } } diff --git 
a/src/renderer/src/i18n/index.ts b/src/renderer/src/i18n/index.ts index a51118d0..ba54df38 100644 --- a/src/renderer/src/i18n/index.ts +++ b/src/renderer/src/i18n/index.ts @@ -80,7 +80,8 @@ const resources = { groq: 'Groq', ollama: 'Ollama', baichuan: 'Baichuan', - dashscope: 'DashScope' + dashscope: 'DashScope', + anthropic: 'Anthropic' }, settings: { title: 'Settings', @@ -197,7 +198,8 @@ const resources = { groq: 'Groq', ollama: 'Ollama', baichuan: '百川', - dashscope: '阿里云灵积' + dashscope: '阿里云灵积', + anthropic: 'Anthropic' }, settings: { title: '设置', diff --git a/src/renderer/src/services/ProviderSDK.ts b/src/renderer/src/services/ProviderSDK.ts new file mode 100644 index 00000000..2f15cdc2 --- /dev/null +++ b/src/renderer/src/services/ProviderSDK.ts @@ -0,0 +1,142 @@ +import { Assistant, Message, Provider } from '@renderer/types' +import OpenAI from 'openai' +import Anthropic from '@anthropic-ai/sdk' +import { getDefaultModel, getTopNamingModel } from './assistant' +import { + ChatCompletionCreateParamsNonStreaming, + ChatCompletionMessageParam, + ChatCompletionSystemMessageParam +} from 'openai/resources' +import { sum, takeRight } from 'lodash' +import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources' +import { EVENT_NAMES } from './event' +import { removeQuotes } from '@renderer/utils' + +export default class ProviderSDK { + provider: Provider + openaiSdk: OpenAI + anthropicSdk: Anthropic + + constructor(provider: Provider) { + this.provider = provider + const host = provider.apiHost + const baseURL = host.endsWith('/') ? 
host : `${provider.apiHost}/v1/` + this.anthropicSdk = new Anthropic({ apiKey: provider.apiKey, baseURL }) + this.openaiSdk = new OpenAI({ dangerouslyAllowBrowser: true, apiKey: provider.apiKey, baseURL }) + } + + private get isAnthropic() { + return this.provider.id === 'anthropic' + } + + public async completions( + messages: Message[], + assistant: Assistant, + onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void + ) { + const defaultModel = getDefaultModel() + const model = assistant.model || defaultModel + + const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined + + const userMessages = takeRight(messages, 5).map((message) => ({ + role: message.role, + content: message.content + })) + + if (this.isAnthropic) { + await this.anthropicSdk.messages + .stream({ + max_tokens: 1024, + system: assistant.prompt, messages: userMessages as MessageParam[], + model: model.id + }) + .on('text', (text) => onChunk({ text: text || '' })) + .on('finalMessage', (message) => + onChunk({ + usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: sum(Object.values(message.usage)) } + }) + ) + } else { + const stream = await this.openaiSdk.chat.completions.create({ + model: model.id, + messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[], + stream: true + }) + for await (const chunk of stream) { + if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) { + break + } + onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage }) + } + } + } + + public async summaries(messages: Message[], assistant: Assistant): Promise { + const model = getTopNamingModel() || assistant.model || getDefaultModel() + + const userMessages: ChatCompletionMessageParam[] = takeRight(messages, 5).map((message) => ({ + role: 'user', + content: message.content + })) + + const systemMessage: ChatCompletionSystemMessageParam = { + role: 'system', + 
content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要加标点符号' + } + + if (this.isAnthropic) { + const message = await this.anthropicSdk.messages.create({ + messages: [systemMessage, ...userMessages] as Anthropic.Messages.MessageParam[], + model: model.id, + stream: false, + max_tokens: 50 + }) + + return message.content[0].type === 'text' ? message.content[0].text : null + } else { + const response = await this.openaiSdk.chat.completions.create({ + model: model.id, + messages: [systemMessage, ...userMessages], + stream: false + }) + + return removeQuotes(response.choices[0].message?.content || '') + } + } + + public async check(): Promise<{ valid: boolean; error: Error | null }> { + const model = this.provider.models[0] + const body = { + model: model.id, + messages: [{ role: 'user', content: 'hi' }], + max_tokens: 100, + stream: false + } + + try { + if (this.isAnthropic) { + const message = await this.anthropicSdk.messages.create(body as MessageCreateParamsNonStreaming) + return { valid: message.content.length > 0, error: null } + } else { + const response = await this.openaiSdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming) + return { valid: Boolean(response?.choices[0].message), error: null } + } + } catch (error: any) { + return { valid: false, error } + } + } + + public async models(): Promise { + try { + if (this.isAnthropic) { + return [] + } + + const response = await this.openaiSdk.models.list() + return response.data + } catch (error) { + return [] + } + } +} diff --git a/src/renderer/src/services/api.ts b/src/renderer/src/services/api.ts index e91f68f2..5539b5a5 100644 --- a/src/renderer/src/services/api.ts +++ b/src/renderer/src/services/api.ts @@ -1,38 +1,30 @@ -import { Assistant, Message, Provider, Topic } from '@renderer/types' -import { uuid } from '@renderer/utils' -import { EVENT_NAMES, EventEmitter } from './event' -import { ChatCompletionMessageParam, ChatCompletionSystemMessageParam } from 'openai/resources' -import OpenAI 
from 'openai' -import { getAssistantProvider, getDefaultModel, getProviderByModel, getTopNamingModel } from './assistant' -import { takeRight } from 'lodash' -import dayjs from 'dayjs' +import i18n from '@renderer/i18n' import store from '@renderer/store' import { setGenerating } from '@renderer/store/runtime' -import i18n from '@renderer/i18n' +import { Assistant, Message, Provider, Topic } from '@renderer/types' +import { getErrorMessage, uuid } from '@renderer/utils' +import dayjs from 'dayjs' +import { getAssistantProvider, getDefaultModel, getProviderByModel, getTopNamingModel } from './assistant' +import { EVENT_NAMES, EventEmitter } from './event' +import ProviderSDK from './ProviderSDK' -interface FetchChatCompletionParams { +export async function fetchChatCompletion({ + messages, + topic, + assistant, + onResponse +}: { messages: Message[] topic: Topic assistant: Assistant onResponse: (message: Message) => void -} - -const getOpenAiProvider = (provider: Provider) => { - const host = provider.apiHost - return new OpenAI({ - dangerouslyAllowBrowser: true, - apiKey: provider.apiKey, - baseURL: host.endsWith('/') ? host : `${provider.apiHost}/v1/` - }) -} - -export async function fetchChatCompletion({ messages, topic, assistant, onResponse }: FetchChatCompletionParams) { +}) { window.keyv.set(EVENT_NAMES.CHAT_COMPLETION_PAUSED, false) const provider = getAssistantProvider(assistant) - const openaiProvider = getOpenAiProvider(provider) const defaultModel = getDefaultModel() const model = assistant.model || defaultModel + const providerSdk = new ProviderSDK(provider) store.dispatch(setGenerating(true)) @@ -49,79 +41,36 @@ export async function fetchChatCompletion({ messages, topic, assistant, onRespon onResponse({ ...message }) - const systemMessage = assistant.prompt ? 
{ role: 'system', content: assistant.prompt } : undefined - - const userMessages = takeRight(messages, 5).map((message) => ({ - role: message.role, - content: message.content - })) - try { - const stream = await openaiProvider.chat.completions.create({ - model: model.id, - messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[], - stream: true + await providerSdk.completions(messages, assistant, ({ text, usage }) => { + message.content = message.content + text || '' + message.usage = usage + onResponse({ ...message, status: 'pending' }) }) - - let content = '' - let usage: OpenAI.Completions.CompletionUsage | undefined = undefined - - for await (const chunk of stream) { - if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) { - break - } - - content = content + (chunk.choices[0]?.delta?.content || '') - chunk.usage && (usage = chunk.usage) - onResponse({ ...message, content, status: 'pending' }) - } - - message.content = content - message.usage = usage } catch (error: any) { message.content = `Error: ${error.message}` } - const paused = window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED) - message.status = paused ? 'paused' : 'success' + // Update message status + message.status = window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED) ? 
'paused' : 'success' + + // Emit chat completion event EventEmitter.emit(EVENT_NAMES.AI_CHAT_COMPLETION, message) + + // Reset generating state store.dispatch(setGenerating(false)) return message } -interface FetchMessagesSummaryParams { - messages: Message[] - assistant: Assistant -} - -export async function fetchMessagesSummary({ messages, assistant }: FetchMessagesSummaryParams) { +export async function fetchMessagesSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) { const model = getTopNamingModel() || assistant.model || getDefaultModel() const provider = getProviderByModel(model) - const openaiProvider = getOpenAiProvider(provider) - - const userMessages: ChatCompletionMessageParam[] = takeRight(messages, 5).map((message) => ({ - role: 'user', - content: message.content - })) - - const systemMessage: ChatCompletionSystemMessageParam = { - role: 'system', - content: - '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,回复内容不需要用引号引起来,不需要在结尾加上句号。' - } - - const response = await openaiProvider.chat.completions.create({ - model: model.id, - messages: [systemMessage, ...userMessages], - stream: false - }) - - return response.choices[0].message?.content + const providerSdk = new ProviderSDK(provider) + return providerSdk.summaries(messages, assistant) } export async function checkApi(provider: Provider) { - const openaiProvider = getOpenAiProvider(provider) const model = provider.models[0] const key = 'api-check' const style = { marginTop: '3vh' } @@ -141,22 +90,9 @@ export async function checkApi(provider: Provider) { return false } - let valid = false - let errorMessage = '' + const providerSdk = new ProviderSDK(provider) - try { - const response = await openaiProvider.chat.completions.create({ - model: model.id, - messages: [{ role: 'user', content: 'hi' }], - max_tokens: 100, - stream: false - }) - - valid = Boolean(response?.choices[0].message) - } catch (error) { - errorMessage = (error as Error).message - valid = false - } + const { valid, error 
} = await providerSdk.check() window.message[valid ? 'success' : 'error']({ key: 'api-check', @@ -164,17 +100,17 @@ export async function checkApi(provider: Provider) { duration: valid ? 2 : 8, content: valid ? i18n.t('message.api.connection.success') - : i18n.t('message.api.connection.failed') + ' ' + errorMessage + : i18n.t('message.api.connection.failed') + ' : ' + getErrorMessage(error) }) return valid } export async function fetchModels(provider: Provider) { + const providerSdk = new ProviderSDK(provider) + try { - const openaiProvider = getOpenAiProvider(provider) - const response = await openaiProvider.models.list() - return response.data + return await providerSdk.models() } catch (error) { return [] } diff --git a/src/renderer/src/store/llm.ts b/src/renderer/src/store/llm.ts index fbe5fe7f..e248b5f9 100644 --- a/src/renderer/src/store/llm.ts +++ b/src/renderer/src/store/llm.ts @@ -85,6 +85,15 @@ const initialState: LlmState = { isSystem: true, enabled: false }, + { + id: 'anthropic', + name: 'Anthropic', + apiKey: '', + apiHost: 'https://api.anthropic.com/', + models: SYSTEM_MODELS.anthropic.filter((m) => m.enabled), + isSystem: true, + enabled: false + }, { id: 'openrouter', name: 'OpenRouter', diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts index 1a177601..f2aed1e3 100644 --- a/src/renderer/src/store/migrate.ts +++ b/src/renderer/src/store/migrate.ts @@ -194,6 +194,15 @@ const migrate = createMigrate({ models: SYSTEM_MODELS.dashscope.filter((m) => m.enabled), isSystem: true, enabled: false + }, + { + id: 'anthropic', + name: 'Anthropic', + apiKey: '', + apiHost: 'https://api.anthropic.com/', + models: SYSTEM_MODELS.anthropic.filter((m) => m.enabled), + isSystem: true, + enabled: false } ] } diff --git a/src/renderer/src/utils/index.ts b/src/renderer/src/utils/index.ts index f3da50db..02c9d90d 100644 --- a/src/renderer/src/utils/index.ts +++ b/src/renderer/src/utils/index.ts @@ -108,3 +108,27 @@ export async function 
isDev() { const isProd = await isProduction() return !isProd } + +export function getErrorMessage(error: any) { + if (!error) { + return '' + } + + if (typeof error === 'string') { + return error + } + + if (error?.error) { + return getErrorMessage(error.error) + } + + if (error?.message) { + return error.message + } + + return '' +} + +export function removeQuotes(str: string) { + return str.replace(/['"]+/g, '') +} diff --git a/yarn.lock b/yarn.lock index cb08ff89..45c6df60 100644 --- a/yarn.lock +++ b/yarn.lock @@ -87,6 +87,22 @@ __metadata: languageName: node linkType: hard +"@anthropic-ai/sdk@npm:^0.24.3": + version: 0.24.3 + resolution: "@anthropic-ai/sdk@npm:0.24.3" + dependencies: + "@types/node": "npm:^18.11.18" + "@types/node-fetch": "npm:^2.6.4" + abort-controller: "npm:^3.0.0" + agentkeepalive: "npm:^4.2.1" + form-data-encoder: "npm:1.7.2" + formdata-node: "npm:^4.3.2" + node-fetch: "npm:^2.6.7" + web-streams-polyfill: "npm:^3.2.1" + checksum: 10c0/1c73c3df9637522da548d2cddfaf89513dac935c5cdb7c0b3db1c427c069a0de76df935bd189e477822063e9f944360e2d059827d5be4dca33bd388c61e97a30 + languageName: node + linkType: hard + "@babel/code-frame@npm:^7.23.5, @babel/code-frame@npm:^7.24.2": version: 7.24.2 resolution: "@babel/code-frame@npm:7.24.2" @@ -3391,6 +3407,7 @@ __metadata: version: 0.0.0-use.local resolution: "cherry-studio@workspace:." dependencies: + "@anthropic-ai/sdk": "npm:^0.24.3" "@electron-toolkit/eslint-config-prettier": "npm:^2.0.0" "@electron-toolkit/eslint-config-ts": "npm:^1.0.1" "@electron-toolkit/preload": "npm:^3.0.0"