2024-08-28 18:11:35 +08:00

345 lines
11 KiB
TypeScript

import Anthropic from '@anthropic-ai/sdk'
import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources'
import { GoogleGenerativeAI } from '@google/generative-ai'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { isLocalAi } from '@renderer/config/env'
import { getOllamaKeepAliveTime } from '@renderer/hooks/useOllama'
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
import { removeQuotes } from '@renderer/utils'
import axios from 'axios'
import { first, isEmpty, sum, takeRight } from 'lodash'
import OpenAI from 'openai'
import { ChatCompletionCreateParamsNonStreaming, ChatCompletionMessageParam } from 'openai/resources'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from './assistant'
import { EVENT_NAMES } from './event'
/**
 * Facade over the OpenAI-compatible, Anthropic and Gemini SDKs.
 *
 * Every public method dispatches on `provider.id`: `'anthropic'` uses the Anthropic SDK,
 * `'gemini'` uses the Google Generative AI SDK, and every other provider goes through the
 * OpenAI SDK (for `'ollama'` an extra `keep_alive` field is sent).
 */
export default class ProviderSDK {
  provider: Provider
  openaiSdk: OpenAI
  anthropicSdk: Anthropic
  geminiSdk: GoogleGenerativeAI

  /**
   * @param provider - Provider configuration. If `apiHost` ends with `/` it is used verbatim as
   *   the SDK base URL; otherwise `/v1/` is appended to it.
   */
  constructor(provider: Provider) {
    this.provider = provider
    const host = provider.apiHost
    // A trailing slash means "use this URL as-is"; otherwise append the conventional `/v1/` prefix.
    const baseURL = host.endsWith('/') ? host : `${provider.apiHost}/v1/`
    this.anthropicSdk = new Anthropic({ apiKey: provider.apiKey, baseURL })
    this.openaiSdk = new OpenAI({ dangerouslyAllowBrowser: true, apiKey: provider.apiKey, baseURL })
    this.geminiSdk = new GoogleGenerativeAI(provider.apiKey)
  }

  private get isAnthropic() {
    return this.provider.id === 'anthropic'
  }

  private get isGemini() {
    return this.provider.id === 'gemini'
  }

  /** Ollama `keep_alive` value; `undefined` for every other provider. */
  private get keepAliveTime() {
    return this.provider.id === 'ollama' ? getOllamaKeepAliveTime() : undefined
  }

  /**
   * Convert role/content pairs into Gemini chat history (`user` stays `user`, anything else
   * becomes `model`).
   *
   * The Gemini SDK rejects a history whose first entry is not a user turn, which happens when
   * the context window cut starts on an assistant reply, so leading non-user entries are dropped.
   */
  private static toGeminiHistory(messages: Pick<Message, 'role' | 'content'>[]) {
    const firstUserIndex = messages.findIndex((message) => message.role === 'user')
    const trimmed = firstUserIndex === -1 ? [] : messages.slice(firstUserIndex)
    return trimmed.map((message) => ({
      role: message.role === 'user' ? 'user' : 'model',
      parts: [{ text: message.content }]
    }))
  }

  /**
   * Stream a chat completion, reporting progress through `onChunk`.
   *
   * Only the last `contextCount + 1` messages are sent. Streaming stops early (and the returned
   * promise resolves) once `window.keyv` reports `EVENT_NAMES.CHAT_COMPLETION_PAUSED`.
   *
   * @param messages - Conversation history; only `role` and `content` are forwarded.
   * @param assistant - Supplies the model, system prompt and sampling settings.
   * @param onChunk - Called with each text delta and, when the provider reports it, token usage.
   */
  public async completions(
    messages: Message[],
    assistant: Assistant,
    onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
  ) {
    const defaultModel = getDefaultModel()
    const model = assistant.model || defaultModel
    const { contextCount, maxTokens } = getAssistantSettings(assistant)
    const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
    const userMessages = takeRight(messages, contextCount + 1).map((message) => ({
      role: message.role,
      content: message.content
    }))

    if (this.isAnthropic) {
      return new Promise<void>((resolve, reject) => {
        // `messages.stream()` sets `stream: true` itself.
        const stream = this.anthropicSdk.messages
          .stream({
            model: model.id,
            messages: userMessages.filter(Boolean) as MessageParam[],
            max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
            temperature: assistant?.settings?.temperature,
            system: assistant.prompt
          })
          .on('text', (text) => {
            if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
              resolve()
              return stream.controller.abort()
            }
            onChunk({ text })
          })
          .on('finalMessage', (message) => {
            const { input_tokens, output_tokens } = message.usage
            onChunk({
              text: '',
              usage: {
                prompt_tokens: input_tokens,
                completion_tokens: output_tokens,
                // Sum explicitly: `usage` may carry extra nullable or non-numeric fields.
                total_tokens: sum([input_tokens, output_tokens])
              }
            })
            resolve()
          })
          // Aborting on pause emits `abort`; handling it avoids an unhandled rejection from the SDK.
          .on('abort', () => resolve())
          .on('error', (error) => reject(error))
      })
    }

    if (this.isGemini) {
      const geminiModel = this.geminiSdk.getGenerativeModel({
        model: model.id,
        systemInstruction: assistant.prompt,
        generationConfig: {
          maxOutputTokens: maxTokens,
          temperature: assistant?.settings?.temperature
        }
      })
      // The last message is sent as the new turn; everything before it becomes history.
      const userLastMessage = userMessages.pop()
      const chat = geminiModel.startChat({ history: ProviderSDK.toGeminiHistory(userMessages) })
      const userMessagesStream = await chat.sendMessageStream(userLastMessage?.content ?? '')
      for await (const chunk of userMessagesStream.stream) {
        if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
        onChunk({
          text: chunk.text(),
          usage: {
            prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
            completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
            total_tokens: chunk.usageMetadata?.totalTokenCount || 0
          }
        })
      }
      return
    }

    // @ts-ignore key is not typed
    const stream = await this.openaiSdk.chat.completions.create({
      model: model.id,
      messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[],
      stream: true,
      temperature: assistant?.settings?.temperature,
      max_tokens: maxTokens,
      keep_alive: this.keepAliveTime
    })
    for await (const chunk of stream) {
      if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
      onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage })
    }
  }

  /**
   * Run a single non-streaming request using `assistant.prompt` as the system instruction and
   * `message.content` as the user input.
   *
   * @returns The response text, or `''` when the provider returned no text.
   */
  public async translate(message: Message, assistant: Assistant) {
    const defaultModel = getDefaultModel()
    const { maxTokens } = getAssistantSettings(assistant)
    const model = assistant.model || defaultModel
    // Only include a system message when a prompt exists (same rule as `completions`).
    const messages = [
      ...(assistant.prompt ? [{ role: 'system', content: assistant.prompt }] : []),
      { role: 'user', content: message.content }
    ]

    if (this.isAnthropic) {
      const response = await this.anthropicSdk.messages.create({
        model: model.id,
        messages: messages.filter((m) => m.role === 'user') as MessageParam[],
        max_tokens: 4096,
        temperature: assistant?.settings?.temperature,
        system: assistant.prompt,
        stream: false
      })
      const block = response.content[0]
      return block?.type === 'text' ? block.text : ''
    }

    if (this.isGemini) {
      const geminiModel = this.geminiSdk.getGenerativeModel({
        model: model.id,
        systemInstruction: assistant.prompt,
        generationConfig: {
          maxOutputTokens: maxTokens,
          temperature: assistant?.settings?.temperature
        }
      })
      const { response } = await geminiModel.generateContent(message.content)
      return response.text()
    }

    // @ts-ignore key is not typed
    const response = await this.openaiSdk.chat.completions.create({
      model: model.id,
      messages: messages as ChatCompletionMessageParam[],
      stream: false,
      keep_alive: this.keepAliveTime
    })
    return response.choices[0]?.message?.content || ''
  }

  /**
   * Generate a short conversation title from the last five messages.
   *
   * The system instruction (in Chinese) asks for a title of at most 10 characters without
   * punctuation. Prefers the configured topic-naming model, then the assistant's model, then
   * the default model.
   *
   * @returns The title text. The OpenAI-compatible path strips quotes and truncates to 50
   *   characters; the Anthropic path returns `null` when the first content block is not text.
   */
  public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
    const model = getTopNamingModel() || assistant.model || getDefaultModel()
    const userMessages = takeRight(messages, 5).map((message) => ({
      role: message.role,
      content: message.content
    }))
    const systemMessage = {
      role: 'system',
      content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要使用标点符号和其他特殊符号。'
    }

    if (this.isAnthropic) {
      const message = await this.anthropicSdk.messages.create({
        messages: userMessages as Anthropic.Messages.MessageParam[],
        model: model.id,
        system: systemMessage.content,
        stream: false,
        max_tokens: 4096
      })
      const block = message.content[0]
      return block?.type === 'text' ? block.text : null
    }

    if (this.isGemini) {
      const geminiModel = this.geminiSdk.getGenerativeModel({
        model: model.id,
        systemInstruction: systemMessage.content,
        generationConfig: {
          temperature: assistant?.settings?.temperature
        }
      })
      const lastUserMessage = userMessages.pop()
      // `startChat` is synchronous; no await needed.
      const chat = geminiModel.startChat({ history: ProviderSDK.toGeminiHistory(userMessages) })
      const { response } = await chat.sendMessage(lastUserMessage?.content ?? '')
      return response.text()
    }

    // Local models only see the first message to keep the prompt small; drop it if absent.
    const contextMessages = (isLocalAi ? [first(userMessages)] : userMessages).filter(Boolean)
    // @ts-ignore key is not typed
    const response = await this.openaiSdk.chat.completions.create({
      model: model.id,
      messages: [systemMessage, ...contextMessages] as ChatCompletionMessageParam[],
      stream: false,
      max_tokens: 50,
      keep_alive: this.keepAliveTime
    })
    return removeQuotes(response.choices[0]?.message?.content?.substring(0, 50) || '')
  }

  /**
   * Ask the provider's `/advice_questions` endpoint for follow-up question suggestions.
   *
   * This endpoint is non-standard; the response is presumably `{ questions: string[] }` — shape
   * assumed from the access pattern below, confirm against the provider.
   *
   * @returns One suggestion per non-empty question, or `[]` when the assistant has no model or
   *   the response has no questions.
   */
  public async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
    const model = assistant.model
    if (!model) {
      return []
    }
    const response: any = await this.openaiSdk.request({
      method: 'post',
      path: '/advice_questions',
      body: {
        messages: messages.filter((m) => m.role === 'user').map((m) => ({ role: m.role, content: m.content })),
        model: model.id,
        max_tokens: 0,
        temperature: 0,
        n: 0
      }
    })
    return response?.questions?.filter(Boolean)?.map((q: any) => ({ content: q })) || []
  }

  /**
   * Verify connectivity/credentials by sending "hi" to the provider's first configured model.
   *
   * Never rejects: every failure, including a provider with no models, is reported as
   * `{ valid: false, error }`.
   */
  public async check(): Promise<{ valid: boolean; error: Error | null }> {
    const model = first(this.provider.models)
    if (!model) {
      return { valid: false, error: new Error('No model configured for this provider') }
    }

    const body = {
      model: model.id,
      messages: [{ role: 'user', content: 'hi' }],
      max_tokens: 100,
      stream: false
    }

    try {
      if (this.isAnthropic) {
        const message = await this.anthropicSdk.messages.create(body as MessageCreateParamsNonStreaming)
        return {
          valid: message.content.length > 0,
          error: null
        }
      }

      if (this.isGemini) {
        const geminiModel = this.geminiSdk.getGenerativeModel({ model: body.model })
        const result = await geminiModel.generateContent(body.messages[0].content)
        return {
          valid: !isEmpty(result.response.text()),
          error: null
        }
      }

      const response = await this.openaiSdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming)
      return {
        valid: Boolean(response?.choices?.[0]?.message),
        error: null
      }
    } catch (error: any) {
      return {
        valid: false,
        error
      }
    }
  }

  /**
   * List the models available from the provider, in OpenAI `Model` shape.
   *
   * Anthropic has no listing here and returns `[]`. Gemini is queried through its REST
   * `v1beta/models` endpoint and mapped into the OpenAI shape. This is best-effort: any failure
   * yields `[]`.
   */
  public async models(): Promise<OpenAI.Models.Model[]> {
    try {
      if (this.isAnthropic) {
        return []
      }

      if (this.isGemini) {
        const api = this.provider.apiHost + '/v1beta/models'
        const { data } = await axios.get(api, { params: { key: this.provider.apiKey } })
        // OpenAI's `Model.created` is a Unix timestamp in seconds, not milliseconds.
        const created = Math.floor(Date.now() / 1000)
        return (data?.models ?? []).map(
          (m: any) =>
            ({
              id: m.name.replace('models/', ''),
              name: m.displayName,
              description: m.description,
              object: 'model',
              created,
              owned_by: 'gemini'
            }) as OpenAI.Models.Model
        )
      }

      const response = await this.openaiSdk.models.list()
      return response.data
    } catch (error) {
      // Deliberately best-effort: callers treat an empty list as "no models available".
      return []
    }
  }
}