feat: add anthropic provider

This commit is contained in:
kangfenmao 2024-07-19 15:34:34 +08:00
parent c4394b925d
commit 31284a6e23
12 changed files with 284 additions and 101 deletions

View File

@ -29,6 +29,7 @@
"electron-window-state": "^5.0.3" "electron-window-state": "^5.0.3"
}, },
"devDependencies": { "devDependencies": {
"@anthropic-ai/sdk": "^0.24.3",
"@electron-toolkit/eslint-config-prettier": "^2.0.0", "@electron-toolkit/eslint-config-prettier": "^2.0.0",
"@electron-toolkit/eslint-config-ts": "^1.0.1", "@electron-toolkit/eslint-config-ts": "^1.0.1",
"@electron-toolkit/tsconfig": "^1.0.1", "@electron-toolkit/tsconfig": "^1.0.1",

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.8 KiB

View File

@ -386,5 +386,35 @@ export const SYSTEM_MODELS: Record<string, SystemModel[]> = {
group: 'Gemma', group: 'Gemma',
enabled: false enabled: false
} }
],
anthropic: [
{
id: 'claude-3-5-sonnet-20240620',
provider: 'anthropic',
name: 'Claude 3.5 Sonnet',
group: 'Claude 3.5',
enabled: true
},
{
id: 'claude-3-opus-20240229',
provider: 'anthropic',
name: 'Claude 3 Opus',
group: 'Claude 3',
enabled: true
},
{
id: 'claude-3-sonnet-20240229',
provider: 'anthropic',
name: 'Claude 3 Sonnet',
group: 'Claude 3',
enabled: true
},
{
id: 'claude-3-haiku-20240307',
provider: 'anthropic',
name: 'Claude 3 Haiku',
group: 'Claude 3',
enabled: true
}
] ]
} }

View File

@ -9,6 +9,7 @@ import MoonshotProviderLogo from '@renderer/assets/images/providers/moonshot.jpe
import OpenRouterProviderLogo from '@renderer/assets/images/providers/openrouter.png' import OpenRouterProviderLogo from '@renderer/assets/images/providers/openrouter.png'
import BaichuanProviderLogo from '@renderer/assets/images/providers/baichuan.png' import BaichuanProviderLogo from '@renderer/assets/images/providers/baichuan.png'
import DashScopeProviderLogo from '@renderer/assets/images/providers/dashscope.png' import DashScopeProviderLogo from '@renderer/assets/images/providers/dashscope.png'
import AnthropicProviderLogo from '@renderer/assets/images/providers/anthropic.jpeg'
import ChatGPTModelLogo from '@renderer/assets/images/models/chatgpt.jpeg' import ChatGPTModelLogo from '@renderer/assets/images/models/chatgpt.jpeg'
import ChatGLMModelLogo from '@renderer/assets/images/models/chatglm.jpeg' import ChatGLMModelLogo from '@renderer/assets/images/models/chatglm.jpeg'
import DeepSeekModelLogo from '@renderer/assets/images/models/deepseek.png' import DeepSeekModelLogo from '@renderer/assets/images/models/deepseek.png'
@ -20,6 +21,7 @@ import MixtralModelLogo from '@renderer/assets/images/models/mixtral.jpeg'
import MoonshotModelLogo from '@renderer/assets/images/providers/moonshot.jpeg' import MoonshotModelLogo from '@renderer/assets/images/providers/moonshot.jpeg'
import MicrosoftModelLogo from '@renderer/assets/images/models/microsoft.png' import MicrosoftModelLogo from '@renderer/assets/images/models/microsoft.png'
import BaichuanModelLogo from '@renderer/assets/images/models/baichuan.png' import BaichuanModelLogo from '@renderer/assets/images/models/baichuan.png'
import ClaudeModelLogo from '@renderer/assets/images/models/claude.png'
export function getProviderLogo(providerId: string) { export function getProviderLogo(providerId: string) {
switch (providerId) { switch (providerId) {
@ -45,6 +47,8 @@ export function getProviderLogo(providerId: string) {
return BaichuanProviderLogo return BaichuanProviderLogo
case 'dashscope': case 'dashscope':
return DashScopeProviderLogo return DashScopeProviderLogo
case 'anthropic':
return AnthropicProviderLogo
default: default:
return undefined return undefined
} }
@ -63,7 +67,8 @@ export function getModelLogo(modelId: string) {
mistral: MixtralModelLogo, mistral: MixtralModelLogo,
moonshot: MoonshotModelLogo, moonshot: MoonshotModelLogo,
phi: MicrosoftModelLogo, phi: MicrosoftModelLogo,
baichuan: BaichuanModelLogo baichuan: BaichuanModelLogo,
claude: ClaudeModelLogo
} }
for (const key in logoMap) { for (const key in logoMap) {
@ -162,5 +167,13 @@ export const PROVIDER_CONFIG = {
docs: 'https://github.com/ollama/ollama/tree/main/docs', docs: 'https://github.com/ollama/ollama/tree/main/docs',
models: 'https://ollama.com/library' models: 'https://ollama.com/library'
} }
},
anthropic: {
websites: {
official: 'https://anthropic.com/',
apiKey: 'https://console.anthropic.com/settings/keys',
docs: 'https://docs.anthropic.com/en/docs',
models: 'https://docs.anthropic.com/en/docs/about-claude/models'
}
} }
} }

View File

@ -80,7 +80,8 @@ const resources = {
groq: 'Groq', groq: 'Groq',
ollama: 'Ollama', ollama: 'Ollama',
baichuan: 'Baichuan', baichuan: 'Baichuan',
dashscope: 'DashScope' dashscope: 'DashScope',
anthropic: 'Anthropic'
}, },
settings: { settings: {
title: 'Settings', title: 'Settings',
@ -197,7 +198,8 @@ const resources = {
groq: 'Groq', groq: 'Groq',
ollama: 'Ollama', ollama: 'Ollama',
baichuan: '百川', baichuan: '百川',
dashscope: '阿里云灵积' dashscope: '阿里云灵积',
anthropic: 'Anthropic'
}, },
settings: { settings: {
title: '设置', title: '设置',

View File

@ -0,0 +1,142 @@
import { Assistant, Message, Provider } from '@renderer/types'
import OpenAI from 'openai'
import Anthropic from '@anthropic-ai/sdk'
import { getDefaultModel, getTopNamingModel } from './assistant'
import {
ChatCompletionCreateParamsNonStreaming,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam
} from 'openai/resources'
import { sum, takeRight } from 'lodash'
import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources'
import { EVENT_NAMES } from './event'
import { removeQuotes } from '@renderer/utils'
export default class ProviderSDK {
  provider: Provider
  openaiSdk: OpenAI
  anthropicSdk: Anthropic

  constructor(provider: Provider) {
    this.provider = provider
    const host = provider.apiHost
    // OpenAI-compatible endpoints expect a trailing `/v1/` on the base URL.
    // The Anthropic SDK already prefixes its request paths with `/v1/`, so it
    // must receive the bare host — appending `/v1/` here would make requests
    // hit `/v1/v1/messages` whenever the configured host lacks a trailing '/'.
    const openaiBaseURL = host.endsWith('/') ? host : `${host}/v1/`
    // dangerouslyAllowBrowser is required when running inside the Electron
    // renderer (a browser-like environment) — both SDKs refuse to run in a
    // browser without it. TODO(review): confirm the flag exists in the pinned
    // @anthropic-ai/sdk version (^0.24.3).
    this.anthropicSdk = new Anthropic({ apiKey: provider.apiKey, baseURL: host, dangerouslyAllowBrowser: true })
    this.openaiSdk = new OpenAI({ dangerouslyAllowBrowser: true, apiKey: provider.apiKey, baseURL: openaiBaseURL })
  }

  /** True when this provider targets the Anthropic Messages API. */
  private get isAnthropic() {
    return this.provider.id === 'anthropic'
  }

  /**
   * Stream a chat completion over the last 5 conversation messages.
   * Text deltas are forwarded to `onChunk` as they arrive; usage figures are
   * reported in OpenAI's `CompletionUsage` shape for both providers.
   */
  public async completions(
    messages: Message[],
    assistant: Assistant,
    onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
  ) {
    const defaultModel = getDefaultModel()
    const model = assistant.model || defaultModel
    const userMessages = takeRight(messages, 5).map((message) => ({
      role: message.role,
      content: message.content
    }))

    if (this.isAnthropic) {
      // Anthropic takes the system prompt as a top-level `system` parameter;
      // a `role: 'system'` entry inside `messages` is rejected by the API.
      // NOTE(review): unlike the OpenAI branch below, CHAT_COMPLETION_PAUSED
      // does not interrupt this stream — consider MessageStream#abort().
      await this.anthropicSdk.messages
        .stream({
          model: model.id,
          max_tokens: 1024,
          system: assistant.prompt || undefined,
          messages: userMessages as MessageParam[]
        })
        .on('text', (text) => onChunk({ text: text || '' }))
        .on('finalMessage', (message) =>
          onChunk({
            // Map Anthropic's input/output token counts onto the OpenAI-shaped
            // usage object instead of collapsing everything into total_tokens.
            usage: {
              prompt_tokens: message.usage.input_tokens,
              completion_tokens: message.usage.output_tokens,
              total_tokens: message.usage.input_tokens + message.usage.output_tokens
            }
          })
        )
    } else {
      const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
      const stream = await this.openaiSdk.chat.completions.create({
        model: model.id,
        messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[],
        stream: true
      })
      for await (const chunk of stream) {
        // Stop consuming the stream as soon as the user pauses generation.
        if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
          break
        }
        onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage })
      }
    }
  }

  /**
   * Summarize the last 5 messages into a short topic title.
   * Returns null when the Anthropic response carries no text block.
   */
  public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
    const model = getTopNamingModel() || assistant.model || getDefaultModel()
    const systemPrompt = '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要加标点符号'
    const userMessages = takeRight(messages, 5).map((message) => ({
      role: 'user' as const,
      content: message.content
    }))

    if (this.isAnthropic) {
      // System prompt goes in the dedicated `system` field, not the messages
      // array — the Messages API rejects `role: 'system'` entries.
      const message = await this.anthropicSdk.messages.create({
        model: model.id,
        system: systemPrompt,
        messages: userMessages as Anthropic.Messages.MessageParam[],
        stream: false,
        max_tokens: 50
      })
      return message.content[0].type === 'text' ? message.content[0].text : null
    }

    const systemMessage: ChatCompletionSystemMessageParam = { role: 'system', content: systemPrompt }
    const response = await this.openaiSdk.chat.completions.create({
      model: model.id,
      messages: [systemMessage, ...userMessages],
      stream: false
    })
    return removeQuotes(response.choices[0].message?.content || '')
  }

  /**
   * Probe the provider with a minimal one-shot request to validate the
   * configured API key/host. Never throws; failures come back as `error`.
   */
  public async check(): Promise<{ valid: boolean; error: Error | null }> {
    const model = this.provider.models[0]
    const body = {
      model: model.id,
      messages: [{ role: 'user', content: 'hi' }],
      max_tokens: 100,
      stream: false
    }
    try {
      if (this.isAnthropic) {
        const message = await this.anthropicSdk.messages.create(body as MessageCreateParamsNonStreaming)
        return { valid: message.content.length > 0, error: null }
      } else {
        const response = await this.openaiSdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming)
        return { valid: Boolean(response?.choices[0].message), error: null }
      }
    } catch (error: any) {
      return { valid: false, error }
    }
  }

  /**
   * List the provider's models. Anthropic exposes no list endpoint in this
   * SDK version, so it yields an empty list; errors also yield [].
   */
  public async models(): Promise<OpenAI.Models.Model[]> {
    try {
      if (this.isAnthropic) {
        return []
      }
      const response = await this.openaiSdk.models.list()
      return response.data
    } catch (error) {
      return []
    }
  }
}

View File

@ -1,38 +1,30 @@
import { Assistant, Message, Provider, Topic } from '@renderer/types' import i18n from '@renderer/i18n'
import { uuid } from '@renderer/utils'
import { EVENT_NAMES, EventEmitter } from './event'
import { ChatCompletionMessageParam, ChatCompletionSystemMessageParam } from 'openai/resources'
import OpenAI from 'openai'
import { getAssistantProvider, getDefaultModel, getProviderByModel, getTopNamingModel } from './assistant'
import { takeRight } from 'lodash'
import dayjs from 'dayjs'
import store from '@renderer/store' import store from '@renderer/store'
import { setGenerating } from '@renderer/store/runtime' import { setGenerating } from '@renderer/store/runtime'
import i18n from '@renderer/i18n' import { Assistant, Message, Provider, Topic } from '@renderer/types'
import { getErrorMessage, uuid } from '@renderer/utils'
import dayjs from 'dayjs'
import { getAssistantProvider, getDefaultModel, getProviderByModel, getTopNamingModel } from './assistant'
import { EVENT_NAMES, EventEmitter } from './event'
import ProviderSDK from './ProviderSDK'
interface FetchChatCompletionParams { export async function fetchChatCompletion({
messages,
topic,
assistant,
onResponse
}: {
messages: Message[] messages: Message[]
topic: Topic topic: Topic
assistant: Assistant assistant: Assistant
onResponse: (message: Message) => void onResponse: (message: Message) => void
} }) {
const getOpenAiProvider = (provider: Provider) => {
const host = provider.apiHost
return new OpenAI({
dangerouslyAllowBrowser: true,
apiKey: provider.apiKey,
baseURL: host.endsWith('/') ? host : `${provider.apiHost}/v1/`
})
}
export async function fetchChatCompletion({ messages, topic, assistant, onResponse }: FetchChatCompletionParams) {
window.keyv.set(EVENT_NAMES.CHAT_COMPLETION_PAUSED, false) window.keyv.set(EVENT_NAMES.CHAT_COMPLETION_PAUSED, false)
const provider = getAssistantProvider(assistant) const provider = getAssistantProvider(assistant)
const openaiProvider = getOpenAiProvider(provider)
const defaultModel = getDefaultModel() const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel const model = assistant.model || defaultModel
const providerSdk = new ProviderSDK(provider)
store.dispatch(setGenerating(true)) store.dispatch(setGenerating(true))
@ -49,79 +41,36 @@ export async function fetchChatCompletion({ messages, topic, assistant, onRespon
onResponse({ ...message }) onResponse({ ...message })
const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
const userMessages = takeRight(messages, 5).map((message) => ({
role: message.role,
content: message.content
}))
try { try {
const stream = await openaiProvider.chat.completions.create({ await providerSdk.completions(messages, assistant, ({ text, usage }) => {
model: model.id, message.content = message.content + text || ''
messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[], message.usage = usage
stream: true onResponse({ ...message, status: 'pending' })
}) })
let content = ''
let usage: OpenAI.Completions.CompletionUsage | undefined = undefined
for await (const chunk of stream) {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
break
}
content = content + (chunk.choices[0]?.delta?.content || '')
chunk.usage && (usage = chunk.usage)
onResponse({ ...message, content, status: 'pending' })
}
message.content = content
message.usage = usage
} catch (error: any) { } catch (error: any) {
message.content = `Error: ${error.message}` message.content = `Error: ${error.message}`
} }
const paused = window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED) // Update message status
message.status = paused ? 'paused' : 'success' message.status = window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED) ? 'paused' : 'success'
// Emit chat completion event
EventEmitter.emit(EVENT_NAMES.AI_CHAT_COMPLETION, message) EventEmitter.emit(EVENT_NAMES.AI_CHAT_COMPLETION, message)
// Reset generating state
store.dispatch(setGenerating(false)) store.dispatch(setGenerating(false))
return message return message
} }
interface FetchMessagesSummaryParams { export async function fetchMessagesSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
messages: Message[]
assistant: Assistant
}
export async function fetchMessagesSummary({ messages, assistant }: FetchMessagesSummaryParams) {
const model = getTopNamingModel() || assistant.model || getDefaultModel() const model = getTopNamingModel() || assistant.model || getDefaultModel()
const provider = getProviderByModel(model) const provider = getProviderByModel(model)
const openaiProvider = getOpenAiProvider(provider) const providerSdk = new ProviderSDK(provider)
return providerSdk.summaries(messages, assistant)
const userMessages: ChatCompletionMessageParam[] = takeRight(messages, 5).map((message) => ({
role: 'user',
content: message.content
}))
const systemMessage: ChatCompletionSystemMessageParam = {
role: 'system',
content:
'你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,回复内容不需要用引号引起来,不需要在结尾加上句号。'
}
const response = await openaiProvider.chat.completions.create({
model: model.id,
messages: [systemMessage, ...userMessages],
stream: false
})
return response.choices[0].message?.content
} }
export async function checkApi(provider: Provider) { export async function checkApi(provider: Provider) {
const openaiProvider = getOpenAiProvider(provider)
const model = provider.models[0] const model = provider.models[0]
const key = 'api-check' const key = 'api-check'
const style = { marginTop: '3vh' } const style = { marginTop: '3vh' }
@ -141,22 +90,9 @@ export async function checkApi(provider: Provider) {
return false return false
} }
let valid = false const providerSdk = new ProviderSDK(provider)
let errorMessage = ''
try { const { valid, error } = await providerSdk.check()
const response = await openaiProvider.chat.completions.create({
model: model.id,
messages: [{ role: 'user', content: 'hi' }],
max_tokens: 100,
stream: false
})
valid = Boolean(response?.choices[0].message)
} catch (error) {
errorMessage = (error as Error).message
valid = false
}
window.message[valid ? 'success' : 'error']({ window.message[valid ? 'success' : 'error']({
key: 'api-check', key: 'api-check',
@ -164,17 +100,17 @@ export async function checkApi(provider: Provider) {
duration: valid ? 2 : 8, duration: valid ? 2 : 8,
content: valid content: valid
? i18n.t('message.api.connection.success') ? i18n.t('message.api.connection.success')
: i18n.t('message.api.connection.failed') + ' ' + errorMessage : i18n.t('message.api.connection.failed') + ' : ' + getErrorMessage(error)
}) })
return valid return valid
} }
export async function fetchModels(provider: Provider) { export async function fetchModels(provider: Provider) {
const providerSdk = new ProviderSDK(provider)
try { try {
const openaiProvider = getOpenAiProvider(provider) return await providerSdk.models()
const response = await openaiProvider.models.list()
return response.data
} catch (error) { } catch (error) {
return [] return []
} }

View File

@ -85,6 +85,15 @@ const initialState: LlmState = {
isSystem: true, isSystem: true,
enabled: false enabled: false
}, },
{
id: 'anthropic',
name: 'Anthropic',
apiKey: '',
apiHost: 'https://api.anthropic.com/',
models: SYSTEM_MODELS.anthropic.filter((m) => m.enabled),
isSystem: true,
enabled: false
},
{ {
id: 'openrouter', id: 'openrouter',
name: 'OpenRouter', name: 'OpenRouter',

View File

@ -194,6 +194,15 @@ const migrate = createMigrate({
models: SYSTEM_MODELS.dashscope.filter((m) => m.enabled), models: SYSTEM_MODELS.dashscope.filter((m) => m.enabled),
isSystem: true, isSystem: true,
enabled: false enabled: false
},
{
id: 'anthropic',
name: 'Anthropic',
apiKey: '',
apiHost: 'https://api.anthropic.com/',
models: SYSTEM_MODELS.anthropic.filter((m) => m.enabled),
isSystem: true,
enabled: false
} }
] ]
} }

View File

@ -108,3 +108,27 @@ export async function isDev() {
const isProd = await isProduction() const isProd = await isProduction()
return !isProd return !isProd
} }
/**
 * Best-effort extraction of a human-readable message from an arbitrary
 * error value. Handles plain strings, Error-like objects, and API payloads
 * that nest the real error under an `error` property; anything else
 * yields an empty string.
 */
export function getErrorMessage(error: any) {
  if (!error) {
    return ''
  }
  if (typeof error === 'string') {
    return error
  }
  // Many SDK payloads wrap the actual error one level deep — unwrap it.
  if (error.error) {
    return getErrorMessage(error.error)
  }
  return error.message || ''
}
/**
 * Strip every single- and double-quote character from a string.
 * Used to clean model-generated titles that arrive wrapped in quotes.
 *
 * @param str - input text (previously implicitly `any`; now typed)
 * @returns the input with all `'` and `"` characters removed
 */
export function removeQuotes(str: string): string {
  return str.replace(/['"]+/g, '')
}

View File

@ -87,6 +87,22 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"@anthropic-ai/sdk@npm:^0.24.3":
version: 0.24.3
resolution: "@anthropic-ai/sdk@npm:0.24.3"
dependencies:
"@types/node": "npm:^18.11.18"
"@types/node-fetch": "npm:^2.6.4"
abort-controller: "npm:^3.0.0"
agentkeepalive: "npm:^4.2.1"
form-data-encoder: "npm:1.7.2"
formdata-node: "npm:^4.3.2"
node-fetch: "npm:^2.6.7"
web-streams-polyfill: "npm:^3.2.1"
checksum: 10c0/1c73c3df9637522da548d2cddfaf89513dac935c5cdb7c0b3db1c427c069a0de76df935bd189e477822063e9f944360e2d059827d5be4dca33bd388c61e97a30
languageName: node
linkType: hard
"@babel/code-frame@npm:^7.23.5, @babel/code-frame@npm:^7.24.2": "@babel/code-frame@npm:^7.23.5, @babel/code-frame@npm:^7.24.2":
version: 7.24.2 version: 7.24.2
resolution: "@babel/code-frame@npm:7.24.2" resolution: "@babel/code-frame@npm:7.24.2"
@ -3391,6 +3407,7 @@ __metadata:
version: 0.0.0-use.local version: 0.0.0-use.local
resolution: "cherry-studio@workspace:." resolution: "cherry-studio@workspace:."
dependencies: dependencies:
"@anthropic-ai/sdk": "npm:^0.24.3"
"@electron-toolkit/eslint-config-prettier": "npm:^2.0.0" "@electron-toolkit/eslint-config-prettier": "npm:^2.0.0"
"@electron-toolkit/eslint-config-ts": "npm:^1.0.1" "@electron-toolkit/eslint-config-ts": "npm:^1.0.1"
"@electron-toolkit/preload": "npm:^3.0.0" "@electron-toolkit/preload": "npm:^3.0.0"