feat: add anthropic provider
This commit is contained in:
parent
c4394b925d
commit
31284a6e23
@ -29,6 +29,7 @@
|
||||
"electron-window-state": "^5.0.3"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@anthropic-ai/sdk": "^0.24.3",
|
||||
"@electron-toolkit/eslint-config-prettier": "^2.0.0",
|
||||
"@electron-toolkit/eslint-config-ts": "^1.0.1",
|
||||
"@electron-toolkit/tsconfig": "^1.0.1",
|
||||
|
||||
BIN
src/renderer/src/assets/images/models/claude.png
Normal file
BIN
src/renderer/src/assets/images/models/claude.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 10 KiB |
BIN
src/renderer/src/assets/images/providers/anthropic.jpeg
Normal file
BIN
src/renderer/src/assets/images/providers/anthropic.jpeg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 6.8 KiB |
@ -386,5 +386,35 @@ export const SYSTEM_MODELS: Record<string, SystemModel[]> = {
|
||||
group: 'Gemma',
|
||||
enabled: false
|
||||
}
|
||||
],
|
||||
anthropic: [
|
||||
{
|
||||
id: 'claude-3-5-sonnet-20240620',
|
||||
provider: 'anthropic',
|
||||
name: 'Claude 3.5 Sonnet',
|
||||
group: 'Claude 3.5',
|
||||
enabled: true
|
||||
},
|
||||
{
|
||||
id: 'claude-3-opus-20240229',
|
||||
provider: 'anthropic',
|
||||
name: 'Claude 3 Opus',
|
||||
group: 'Claude 3',
|
||||
enabled: true
|
||||
},
|
||||
{
|
||||
id: 'claude-3-sonnet-20240229',
|
||||
provider: 'anthropic',
|
||||
name: 'Claude 3 Sonnet',
|
||||
group: 'Claude 3',
|
||||
enabled: true
|
||||
},
|
||||
{
|
||||
id: 'claude-3-haiku-20240307',
|
||||
provider: 'anthropic',
|
||||
name: 'Claude 3 Haiku',
|
||||
group: 'Claude 3',
|
||||
enabled: true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -9,6 +9,7 @@ import MoonshotProviderLogo from '@renderer/assets/images/providers/moonshot.jpe
|
||||
import OpenRouterProviderLogo from '@renderer/assets/images/providers/openrouter.png'
|
||||
import BaichuanProviderLogo from '@renderer/assets/images/providers/baichuan.png'
|
||||
import DashScopeProviderLogo from '@renderer/assets/images/providers/dashscope.png'
|
||||
import AnthropicProviderLogo from '@renderer/assets/images/providers/anthropic.jpeg'
|
||||
import ChatGPTModelLogo from '@renderer/assets/images/models/chatgpt.jpeg'
|
||||
import ChatGLMModelLogo from '@renderer/assets/images/models/chatglm.jpeg'
|
||||
import DeepSeekModelLogo from '@renderer/assets/images/models/deepseek.png'
|
||||
@ -20,6 +21,7 @@ import MixtralModelLogo from '@renderer/assets/images/models/mixtral.jpeg'
|
||||
import MoonshotModelLogo from '@renderer/assets/images/providers/moonshot.jpeg'
|
||||
import MicrosoftModelLogo from '@renderer/assets/images/models/microsoft.png'
|
||||
import BaichuanModelLogo from '@renderer/assets/images/models/baichuan.png'
|
||||
import ClaudeModelLogo from '@renderer/assets/images/models/claude.png'
|
||||
|
||||
export function getProviderLogo(providerId: string) {
|
||||
switch (providerId) {
|
||||
@ -45,6 +47,8 @@ export function getProviderLogo(providerId: string) {
|
||||
return BaichuanProviderLogo
|
||||
case 'dashscope':
|
||||
return DashScopeProviderLogo
|
||||
case 'anthropic':
|
||||
return AnthropicProviderLogo
|
||||
default:
|
||||
return undefined
|
||||
}
|
||||
@ -63,7 +67,8 @@ export function getModelLogo(modelId: string) {
|
||||
mistral: MixtralModelLogo,
|
||||
moonshot: MoonshotModelLogo,
|
||||
phi: MicrosoftModelLogo,
|
||||
baichuan: BaichuanModelLogo
|
||||
baichuan: BaichuanModelLogo,
|
||||
claude: ClaudeModelLogo
|
||||
}
|
||||
|
||||
for (const key in logoMap) {
|
||||
@ -162,5 +167,13 @@ export const PROVIDER_CONFIG = {
|
||||
docs: 'https://github.com/ollama/ollama/tree/main/docs',
|
||||
models: 'https://ollama.com/library'
|
||||
}
|
||||
},
|
||||
anthropic: {
|
||||
websites: {
|
||||
official: 'https://anthropic.com/',
|
||||
apiKey: 'https://console.anthropic.com/settings/keys',
|
||||
docs: 'https://docs.anthropic.com/en/docs',
|
||||
models: 'https://docs.anthropic.com/en/docs/about-claude/models'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -80,7 +80,8 @@ const resources = {
|
||||
groq: 'Groq',
|
||||
ollama: 'Ollama',
|
||||
baichuan: 'Baichuan',
|
||||
dashscope: 'DashScope'
|
||||
dashscope: 'DashScope',
|
||||
anthropic: 'Anthropic'
|
||||
},
|
||||
settings: {
|
||||
title: 'Settings',
|
||||
@ -197,7 +198,8 @@ const resources = {
|
||||
groq: 'Groq',
|
||||
ollama: 'Ollama',
|
||||
baichuan: '百川',
|
||||
dashscope: '阿里云灵积'
|
||||
dashscope: '阿里云灵积',
|
||||
anthropic: 'Anthropic'
|
||||
},
|
||||
settings: {
|
||||
title: '设置',
|
||||
|
||||
142
src/renderer/src/services/ProviderSDK.ts
Normal file
142
src/renderer/src/services/ProviderSDK.ts
Normal file
@ -0,0 +1,142 @@
|
||||
import { Assistant, Message, Provider } from '@renderer/types'
|
||||
import OpenAI from 'openai'
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import { getDefaultModel, getTopNamingModel } from './assistant'
|
||||
import {
|
||||
ChatCompletionCreateParamsNonStreaming,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionSystemMessageParam
|
||||
} from 'openai/resources'
|
||||
import { sum, takeRight } from 'lodash'
|
||||
import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources'
|
||||
import { EVENT_NAMES } from './event'
|
||||
import { removeQuotes } from '@renderer/utils'
|
||||
|
||||
export default class ProviderSDK {
|
||||
provider: Provider
|
||||
openaiSdk: OpenAI
|
||||
anthropicSdk: Anthropic
|
||||
|
||||
constructor(provider: Provider) {
|
||||
this.provider = provider
|
||||
const host = provider.apiHost
|
||||
const baseURL = host.endsWith('/') ? host : `${provider.apiHost}/v1/`
|
||||
this.anthropicSdk = new Anthropic({ apiKey: provider.apiKey, baseURL })
|
||||
this.openaiSdk = new OpenAI({ dangerouslyAllowBrowser: true, apiKey: provider.apiKey, baseURL })
|
||||
}
|
||||
|
||||
private get isAnthropic() {
|
||||
return this.provider.id === 'anthropic'
|
||||
}
|
||||
|
||||
public async completions(
|
||||
messages: Message[],
|
||||
assistant: Assistant,
|
||||
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
|
||||
) {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
|
||||
const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
|
||||
|
||||
const userMessages = takeRight(messages, 5).map((message) => ({
|
||||
role: message.role,
|
||||
content: message.content
|
||||
}))
|
||||
|
||||
if (this.isAnthropic) {
|
||||
await this.anthropicSdk.messages
|
||||
.stream({
|
||||
max_tokens: 1024,
|
||||
messages: [systemMessage, ...userMessages].filter(Boolean) as MessageParam[],
|
||||
model: model.id
|
||||
})
|
||||
.on('text', (text) => onChunk({ text: text || '' }))
|
||||
.on('finalMessage', (message) =>
|
||||
onChunk({
|
||||
usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: sum(Object.values(message.usage)) }
|
||||
})
|
||||
)
|
||||
} else {
|
||||
const stream = await this.openaiSdk.chat.completions.create({
|
||||
model: model.id,
|
||||
messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[],
|
||||
stream: true
|
||||
})
|
||||
for await (const chunk of stream) {
|
||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
|
||||
break
|
||||
}
|
||||
onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
|
||||
const model = getTopNamingModel() || assistant.model || getDefaultModel()
|
||||
|
||||
const userMessages: ChatCompletionMessageParam[] = takeRight(messages, 5).map((message) => ({
|
||||
role: 'user',
|
||||
content: message.content
|
||||
}))
|
||||
|
||||
const systemMessage: ChatCompletionSystemMessageParam = {
|
||||
role: 'system',
|
||||
content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要加标点符号'
|
||||
}
|
||||
|
||||
if (this.isAnthropic) {
|
||||
const message = await this.anthropicSdk.messages.create({
|
||||
messages: [systemMessage, ...userMessages] as Anthropic.Messages.MessageParam[],
|
||||
model: model.id,
|
||||
stream: false,
|
||||
max_tokens: 50
|
||||
})
|
||||
|
||||
return message.content[0].type === 'text' ? message.content[0].text : null
|
||||
} else {
|
||||
const response = await this.openaiSdk.chat.completions.create({
|
||||
model: model.id,
|
||||
messages: [systemMessage, ...userMessages],
|
||||
stream: false
|
||||
})
|
||||
|
||||
return removeQuotes(response.choices[0].message?.content || '')
|
||||
}
|
||||
}
|
||||
|
||||
public async check(): Promise<{ valid: boolean; error: Error | null }> {
|
||||
const model = this.provider.models[0]
|
||||
const body = {
|
||||
model: model.id,
|
||||
messages: [{ role: 'user', content: 'hi' }],
|
||||
max_tokens: 100,
|
||||
stream: false
|
||||
}
|
||||
|
||||
try {
|
||||
if (this.isAnthropic) {
|
||||
const message = await this.anthropicSdk.messages.create(body as MessageCreateParamsNonStreaming)
|
||||
return { valid: message.content.length > 0, error: null }
|
||||
} else {
|
||||
const response = await this.openaiSdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming)
|
||||
return { valid: Boolean(response?.choices[0].message), error: null }
|
||||
}
|
||||
} catch (error: any) {
|
||||
return { valid: false, error }
|
||||
}
|
||||
}
|
||||
|
||||
public async models(): Promise<OpenAI.Models.Model[]> {
|
||||
try {
|
||||
if (this.isAnthropic) {
|
||||
return []
|
||||
}
|
||||
|
||||
const response = await this.openaiSdk.models.list()
|
||||
return response.data
|
||||
} catch (error) {
|
||||
return []
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,38 +1,30 @@
|
||||
import { Assistant, Message, Provider, Topic } from '@renderer/types'
|
||||
import { uuid } from '@renderer/utils'
|
||||
import { EVENT_NAMES, EventEmitter } from './event'
|
||||
import { ChatCompletionMessageParam, ChatCompletionSystemMessageParam } from 'openai/resources'
|
||||
import OpenAI from 'openai'
|
||||
import { getAssistantProvider, getDefaultModel, getProviderByModel, getTopNamingModel } from './assistant'
|
||||
import { takeRight } from 'lodash'
|
||||
import dayjs from 'dayjs'
|
||||
import i18n from '@renderer/i18n'
|
||||
import store from '@renderer/store'
|
||||
import { setGenerating } from '@renderer/store/runtime'
|
||||
import i18n from '@renderer/i18n'
|
||||
import { Assistant, Message, Provider, Topic } from '@renderer/types'
|
||||
import { getErrorMessage, uuid } from '@renderer/utils'
|
||||
import dayjs from 'dayjs'
|
||||
import { getAssistantProvider, getDefaultModel, getProviderByModel, getTopNamingModel } from './assistant'
|
||||
import { EVENT_NAMES, EventEmitter } from './event'
|
||||
import ProviderSDK from './ProviderSDK'
|
||||
|
||||
interface FetchChatCompletionParams {
|
||||
export async function fetchChatCompletion({
|
||||
messages,
|
||||
topic,
|
||||
assistant,
|
||||
onResponse
|
||||
}: {
|
||||
messages: Message[]
|
||||
topic: Topic
|
||||
assistant: Assistant
|
||||
onResponse: (message: Message) => void
|
||||
}
|
||||
|
||||
const getOpenAiProvider = (provider: Provider) => {
|
||||
const host = provider.apiHost
|
||||
return new OpenAI({
|
||||
dangerouslyAllowBrowser: true,
|
||||
apiKey: provider.apiKey,
|
||||
baseURL: host.endsWith('/') ? host : `${provider.apiHost}/v1/`
|
||||
})
|
||||
}
|
||||
|
||||
export async function fetchChatCompletion({ messages, topic, assistant, onResponse }: FetchChatCompletionParams) {
|
||||
}) {
|
||||
window.keyv.set(EVENT_NAMES.CHAT_COMPLETION_PAUSED, false)
|
||||
|
||||
const provider = getAssistantProvider(assistant)
|
||||
const openaiProvider = getOpenAiProvider(provider)
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const providerSdk = new ProviderSDK(provider)
|
||||
|
||||
store.dispatch(setGenerating(true))
|
||||
|
||||
@ -49,79 +41,36 @@ export async function fetchChatCompletion({ messages, topic, assistant, onRespon
|
||||
|
||||
onResponse({ ...message })
|
||||
|
||||
const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
|
||||
|
||||
const userMessages = takeRight(messages, 5).map((message) => ({
|
||||
role: message.role,
|
||||
content: message.content
|
||||
}))
|
||||
|
||||
try {
|
||||
const stream = await openaiProvider.chat.completions.create({
|
||||
model: model.id,
|
||||
messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[],
|
||||
stream: true
|
||||
await providerSdk.completions(messages, assistant, ({ text, usage }) => {
|
||||
message.content = message.content + text || ''
|
||||
message.usage = usage
|
||||
onResponse({ ...message, status: 'pending' })
|
||||
})
|
||||
|
||||
let content = ''
|
||||
let usage: OpenAI.Completions.CompletionUsage | undefined = undefined
|
||||
|
||||
for await (const chunk of stream) {
|
||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
|
||||
break
|
||||
}
|
||||
|
||||
content = content + (chunk.choices[0]?.delta?.content || '')
|
||||
chunk.usage && (usage = chunk.usage)
|
||||
onResponse({ ...message, content, status: 'pending' })
|
||||
}
|
||||
|
||||
message.content = content
|
||||
message.usage = usage
|
||||
} catch (error: any) {
|
||||
message.content = `Error: ${error.message}`
|
||||
}
|
||||
|
||||
const paused = window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)
|
||||
message.status = paused ? 'paused' : 'success'
|
||||
// Update message status
|
||||
message.status = window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED) ? 'paused' : 'success'
|
||||
|
||||
// Emit chat completion event
|
||||
EventEmitter.emit(EVENT_NAMES.AI_CHAT_COMPLETION, message)
|
||||
|
||||
// Reset generating state
|
||||
store.dispatch(setGenerating(false))
|
||||
|
||||
return message
|
||||
}
|
||||
|
||||
interface FetchMessagesSummaryParams {
|
||||
messages: Message[]
|
||||
assistant: Assistant
|
||||
}
|
||||
|
||||
export async function fetchMessagesSummary({ messages, assistant }: FetchMessagesSummaryParams) {
|
||||
export async function fetchMessagesSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
|
||||
const model = getTopNamingModel() || assistant.model || getDefaultModel()
|
||||
const provider = getProviderByModel(model)
|
||||
const openaiProvider = getOpenAiProvider(provider)
|
||||
|
||||
const userMessages: ChatCompletionMessageParam[] = takeRight(messages, 5).map((message) => ({
|
||||
role: 'user',
|
||||
content: message.content
|
||||
}))
|
||||
|
||||
const systemMessage: ChatCompletionSystemMessageParam = {
|
||||
role: 'system',
|
||||
content:
|
||||
'你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,回复内容不需要用引号引起来,不需要在结尾加上句号。'
|
||||
}
|
||||
|
||||
const response = await openaiProvider.chat.completions.create({
|
||||
model: model.id,
|
||||
messages: [systemMessage, ...userMessages],
|
||||
stream: false
|
||||
})
|
||||
|
||||
return response.choices[0].message?.content
|
||||
const providerSdk = new ProviderSDK(provider)
|
||||
return providerSdk.summaries(messages, assistant)
|
||||
}
|
||||
|
||||
export async function checkApi(provider: Provider) {
|
||||
const openaiProvider = getOpenAiProvider(provider)
|
||||
const model = provider.models[0]
|
||||
const key = 'api-check'
|
||||
const style = { marginTop: '3vh' }
|
||||
@ -141,22 +90,9 @@ export async function checkApi(provider: Provider) {
|
||||
return false
|
||||
}
|
||||
|
||||
let valid = false
|
||||
let errorMessage = ''
|
||||
const providerSdk = new ProviderSDK(provider)
|
||||
|
||||
try {
|
||||
const response = await openaiProvider.chat.completions.create({
|
||||
model: model.id,
|
||||
messages: [{ role: 'user', content: 'hi' }],
|
||||
max_tokens: 100,
|
||||
stream: false
|
||||
})
|
||||
|
||||
valid = Boolean(response?.choices[0].message)
|
||||
} catch (error) {
|
||||
errorMessage = (error as Error).message
|
||||
valid = false
|
||||
}
|
||||
const { valid, error } = await providerSdk.check()
|
||||
|
||||
window.message[valid ? 'success' : 'error']({
|
||||
key: 'api-check',
|
||||
@ -164,17 +100,17 @@ export async function checkApi(provider: Provider) {
|
||||
duration: valid ? 2 : 8,
|
||||
content: valid
|
||||
? i18n.t('message.api.connection.success')
|
||||
: i18n.t('message.api.connection.failed') + ' ' + errorMessage
|
||||
: i18n.t('message.api.connection.failed') + ' : ' + getErrorMessage(error)
|
||||
})
|
||||
|
||||
return valid
|
||||
}
|
||||
|
||||
export async function fetchModels(provider: Provider) {
|
||||
const providerSdk = new ProviderSDK(provider)
|
||||
|
||||
try {
|
||||
const openaiProvider = getOpenAiProvider(provider)
|
||||
const response = await openaiProvider.models.list()
|
||||
return response.data
|
||||
return await providerSdk.models()
|
||||
} catch (error) {
|
||||
return []
|
||||
}
|
||||
|
||||
@ -85,6 +85,15 @@ const initialState: LlmState = {
|
||||
isSystem: true,
|
||||
enabled: false
|
||||
},
|
||||
{
|
||||
id: 'anthropic',
|
||||
name: 'Anthropic',
|
||||
apiKey: '',
|
||||
apiHost: 'https://api.anthropic.com/',
|
||||
models: SYSTEM_MODELS.anthropic.filter((m) => m.enabled),
|
||||
isSystem: true,
|
||||
enabled: false
|
||||
},
|
||||
{
|
||||
id: 'openrouter',
|
||||
name: 'OpenRouter',
|
||||
|
||||
@ -194,6 +194,15 @@ const migrate = createMigrate({
|
||||
models: SYSTEM_MODELS.dashscope.filter((m) => m.enabled),
|
||||
isSystem: true,
|
||||
enabled: false
|
||||
},
|
||||
{
|
||||
id: 'anthropic',
|
||||
name: 'Anthropic',
|
||||
apiKey: '',
|
||||
apiHost: 'https://api.anthropic.com/',
|
||||
models: SYSTEM_MODELS.anthropic.filter((m) => m.enabled),
|
||||
isSystem: true,
|
||||
enabled: false
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -108,3 +108,27 @@ export async function isDev() {
|
||||
const isProd = await isProduction()
|
||||
return !isProd
|
||||
}
|
||||
|
||||
export function getErrorMessage(error: any) {
|
||||
if (!error) {
|
||||
return ''
|
||||
}
|
||||
|
||||
if (typeof error === 'string') {
|
||||
return error
|
||||
}
|
||||
|
||||
if (error?.error) {
|
||||
return getErrorMessage(error.error)
|
||||
}
|
||||
|
||||
if (error?.message) {
|
||||
return error.message
|
||||
}
|
||||
|
||||
return ''
|
||||
}
|
||||
|
||||
export function removeQuotes(str) {
|
||||
return str.replace(/['"]+/g, '')
|
||||
}
|
||||
|
||||
17
yarn.lock
17
yarn.lock
@ -87,6 +87,22 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@anthropic-ai/sdk@npm:^0.24.3":
|
||||
version: 0.24.3
|
||||
resolution: "@anthropic-ai/sdk@npm:0.24.3"
|
||||
dependencies:
|
||||
"@types/node": "npm:^18.11.18"
|
||||
"@types/node-fetch": "npm:^2.6.4"
|
||||
abort-controller: "npm:^3.0.0"
|
||||
agentkeepalive: "npm:^4.2.1"
|
||||
form-data-encoder: "npm:1.7.2"
|
||||
formdata-node: "npm:^4.3.2"
|
||||
node-fetch: "npm:^2.6.7"
|
||||
web-streams-polyfill: "npm:^3.2.1"
|
||||
checksum: 10c0/1c73c3df9637522da548d2cddfaf89513dac935c5cdb7c0b3db1c427c069a0de76df935bd189e477822063e9f944360e2d059827d5be4dca33bd388c61e97a30
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@babel/code-frame@npm:^7.23.5, @babel/code-frame@npm:^7.24.2":
|
||||
version: 7.24.2
|
||||
resolution: "@babel/code-frame@npm:7.24.2"
|
||||
@ -3391,6 +3407,7 @@ __metadata:
|
||||
version: 0.0.0-use.local
|
||||
resolution: "cherry-studio@workspace:."
|
||||
dependencies:
|
||||
"@anthropic-ai/sdk": "npm:^0.24.3"
|
||||
"@electron-toolkit/eslint-config-prettier": "npm:^2.0.0"
|
||||
"@electron-toolkit/eslint-config-ts": "npm:^1.0.1"
|
||||
"@electron-toolkit/preload": "npm:^3.0.0"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user