feat: paintings add prompt enhancement params

This commit is contained in:
kangfenmao 2025-01-02 14:51:52 +08:00
parent 6384525e20
commit 038aa2d5cc
12 changed files with 109 additions and 68 deletions

View File

@ -961,6 +961,11 @@ export const TEXT_TO_IMAGES_MODELS = [
}
]
/**
 * Ids of the text-to-image models listed as supporting enhancement.
 *
 * NOTE(review): nothing in the visible part of this change reads this list,
 * and the name says "image enhancement" while the feature added alongside it
 * is "prompt enhancement" — confirm the intended consumer and naming.
 */
export const TEXT_TO_IMAGES_MODELS_SUPPORT_IMAGE_ENHANCEMENT: string[] = [
  'stabilityai/stable-diffusion-2-1',
  'stabilityai/stable-diffusion-xl-base-1.0'
]
/**
 * Reports whether the given model is a text-to-image model.
 *
 * The decision is made solely by testing the model's `id` against
 * `TEXT_TO_IMAGE_REGEX`.
 *
 * @param model - Model whose `id` is checked.
 * @returns `true` when the id matches the regex, otherwise `false`.
 */
export function isTextToImageModel(model: Model): boolean {
  const { id } = model
  return TEXT_TO_IMAGE_REGEX.test(id)
}

View File

@ -14,6 +14,7 @@ export function usePaintings() {
paintings,
addPainting: () => {
const newPainting: Painting = {
model: TEXT_TO_IMAGES_MODELS[0].id,
id: uuid(),
urls: [],
files: [],
@ -24,7 +25,7 @@ export function usePaintings() {
seed: generateRandomSeed(),
steps: 25,
guidanceScale: 4.5,
model: TEXT_TO_IMAGES_MODELS[0].id
promptEnhancement: true
}
dispatch(addPainting(newPainting))
return newPainting

View File

@ -279,7 +279,9 @@
"regenerate.confirm": "This will replace your existing generated images. Do you want to continue?",
"seed": "Seed",
"seed_tip": "The same seed and prompt can produce similar images",
"title": "Images"
"title": "Images",
"prompt_enhancement": "Prompt Enhancement",
"prompt_enhancement_tip": "Rewrite prompts into detailed, model-friendly versions when switched on"
},
"provider": {
"aihubmix": "AiHubMix",

View File

@ -277,7 +277,9 @@
"regenerate.confirm": "これにより、既存の生成画像が置き換えられます。続行しますか?",
"seed": "シード",
"seed_tip": "同じシードとプロンプトで似た画像を生成できます",
"title": "画像"
"title": "画像",
"prompt_enhancement": "プロンプト強化",
"prompt_enhancement_tip": "オンにすると、プロンプトを詳細でモデルに適したバージョンに書き直します"
},
"provider": {
"aihubmix": "AiHubMix",

View File

@ -279,7 +279,9 @@
"regenerate.confirm": "Это заменит ваши существующие сгенерированные изображения. Хотите продолжить?",
"seed": "Ключ генерации",
"seed_tip": "Одинаковый ключ генерации и промпт могут производить похожие изображения",
"title": "Изображения"
"title": "Изображения",
"prompt_enhancement": "Улучшение промпта",
"prompt_enhancement_tip": "При включении переписывает промпт в более детальную версию, лучше подходящую для модели"
},
"provider": {
"aihubmix": "AiHubMix",

View File

@ -280,7 +280,9 @@
"regenerate.confirm": "这将覆盖已生成的图片,是否继续?",
"seed": "随机种子",
"seed_tip": "相同的种子和提示词可以生成相似的图片",
"title": "图片"
"title": "图片",
"prompt_enhancement": "提示词增强",
"prompt_enhancement_tip": "开启后将提示词重写为详细的、适合模型的版本"
},
"provider": {
"aihubmix": "AiHubMix",

View File

@ -279,7 +279,9 @@
"regenerate.confirm": "這將覆蓋已生成的圖片,是否繼續?",
"seed": "隨機種子",
"seed_tip": "相同的種子和提示詞可以生成相似的圖片",
"title": "繪圖"
"title": "繪圖",
"prompt_enhancement": "提示詞增強",
"prompt_enhancement_tip": "開啟後將提示詞重寫為詳細的、適合模型的版本"
},
"provider": {
"aihubmix": "AiHubMix",

View File

@ -6,7 +6,7 @@ import ImageSize3_4 from '@renderer/assets/images/paintings/image-size-3-4.svg'
import ImageSize9_16 from '@renderer/assets/images/paintings/image-size-9-16.svg'
import ImageSize16_9 from '@renderer/assets/images/paintings/image-size-16-9.svg'
import { Navbar, NavbarCenter, NavbarRight } from '@renderer/components/app/Navbar'
import { VStack } from '@renderer/components/Layout'
import { HStack, VStack } from '@renderer/components/Layout'
import Scrollbar from '@renderer/components/Scrollbar'
import TranslateButton from '@renderer/components/TranslateButton'
import { isMac } from '@renderer/config/constant'
@ -25,7 +25,7 @@ import { DEFAULT_PAINTING } from '@renderer/store/paintings'
import { setGenerating } from '@renderer/store/runtime'
import { FileType, Painting } from '@renderer/types'
import { getErrorMessage } from '@renderer/utils'
import { Button, Input, InputNumber, Radio, Select, Slider, Tooltip } from 'antd'
import { Button, Input, InputNumber, Radio, Select, Slider, Switch, Tooltip } from 'antd'
import TextArea from 'antd/es/input/TextArea'
import { FC, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
@ -149,8 +149,13 @@ const PaintingsPage: FC = () => {
dispatch(setGenerating(true))
const AI = new AiProvider(provider)
if (!painting.model) {
return
}
try {
const urls = await AI.generateImage({
model: painting.model,
prompt,
negativePrompt: painting.negativePrompt || '',
imageSize: painting.imageSize || '1024x1024',
@ -158,7 +163,8 @@ const PaintingsPage: FC = () => {
seed: painting.seed || undefined,
numInferenceSteps: painting.steps || 25,
guidanceScale: painting.guidanceScale || 4.5,
signal: controller.signal
signal: controller.signal,
promptEnhancement: painting.promptEnhancement || false
})
if (urls.length > 0) {
@ -360,13 +366,15 @@ const PaintingsPage: FC = () => {
<InfoIcon />
</Tooltip>
</SettingTitle>
<SliderContainer>
<Slider min={1} max={50} value={painting.steps} onChange={(v) => updatePaintingState({ steps: v })} />
<InputNumber
<StyledInputNumber
min={1}
max={50}
value={painting.steps}
onChange={(v) => updatePaintingState({ steps: v || 25 })}
onChange={(v) => updatePaintingState({ steps: (v as number) || 25 })}
/>
</SliderContainer>
<SettingTitle style={{ marginBottom: 5, marginTop: 15 }}>
{t('paintings.guidance_scale')}
@ -374,6 +382,7 @@ const PaintingsPage: FC = () => {
<InfoIcon />
</Tooltip>
</SettingTitle>
<SliderContainer>
<Slider
min={1}
max={20}
@ -381,14 +390,14 @@ const PaintingsPage: FC = () => {
value={painting.guidanceScale}
onChange={(v) => updatePaintingState({ guidanceScale: v })}
/>
<InputNumber
<StyledInputNumber
min={1}
max={20}
step={0.1}
value={painting.guidanceScale}
onChange={(v) => updatePaintingState({ guidanceScale: v || 4.5 })}
onChange={(v) => updatePaintingState({ guidanceScale: (v as number) || 4.5 })}
/>
</SliderContainer>
<SettingTitle style={{ marginBottom: 5, marginTop: 15 }}>
{t('paintings.negative_prompt')}
<Tooltip title={t('paintings.negative_prompt_tip')}>
@ -400,6 +409,18 @@ const PaintingsPage: FC = () => {
onChange={(e) => updatePaintingState({ negativePrompt: e.target.value })}
rows={4}
/>
<SettingTitle style={{ marginBottom: 5, marginTop: 15 }}>
{t('paintings.prompt_enhancement')}
<Tooltip title={t('paintings.prompt_enhancement_tip')}>
<InfoIcon />
</Tooltip>
</SettingTitle>
<HStack>
<Switch
checked={painting.promptEnhancement}
onChange={(checked) => updatePaintingState({ promptEnhancement: checked })}
/>
</HStack>
</LeftContainer>
<MainContainer>
<Artboard
@ -547,4 +568,18 @@ const InfoIcon = styled(QuestionCircleOutlined)`
}
`
// Row wrapper for a Slider and its companion numeric input: children are laid
// out horizontally, vertically centered, 16px apart, and the antd slider
// (`.ant-slider`) grows to take the remaining width.
const SliderContainer = styled.div`
display: flex;
align-items: center;
gap: 16px;
.ant-slider {
flex: 1;
}
`
// antd InputNumber pinned to a fixed 70px width.
// NOTE(review): the call sites visible in this file cast the onChange value
// with `(v as number)` when using this wrapper — presumably because wrapping
// with styled() widens the value type; confirm before removing the casts.
const StyledInputNumber = styled(InputNumber)`
width: 70px;
`
export default PaintingsPage

View File

@ -1,6 +1,6 @@
import BaseProvider from '@renderer/providers/BaseProvider'
import ProviderFactory from '@renderer/providers/ProviderFactory'
import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types'
import { Assistant, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
import OpenAI from 'openai'
import { CompletionsParams } from '.'
@ -48,16 +48,7 @@ export default class AiProvider {
return this.sdk.getApiKey()
}
public async generateImage(params: {
prompt: string
negativePrompt: string
imageSize: string
batchSize: number
seed?: string
numInferenceSteps: number
guidanceScale: number
signal?: AbortSignal
}): Promise<string[]> {
/**
 * Generates images by delegating to the underlying provider SDK.
 *
 * The parameters are passed through unchanged to `this.sdk.generateImage`.
 *
 * @param params - Image generation options (see `GenerateImageParams`).
 * @returns Whatever the SDK resolves with — a list of strings, which the
 * paintings page treats as image URLs.
 */
public async generateImage(params: GenerateImageParams): Promise<string[]> {
return this.sdk.generateImage(params)
}

View File

@ -2,7 +2,7 @@ import { REFERENCE_PROMPT } from '@renderer/config/prompts'
import { getOllamaKeepAliveTime } from '@renderer/hooks/useOllama'
import { getKnowledgeReferences } from '@renderer/services/KnowledgeService'
import store from '@renderer/store'
import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types'
import { Assistant, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
import { delay, isJSON } from '@renderer/utils'
import OpenAI from 'openai'
@ -26,16 +26,7 @@ export default abstract class BaseProvider {
abstract generateText({ prompt, content }: { prompt: string; content: string }): Promise<string>
abstract check(): Promise<{ valid: boolean; error: Error | null }>
abstract models(): Promise<OpenAI.Models.Model[]>
abstract generateImage(_params: {
prompt: string
negativePrompt: string
imageSize: string
batchSize: number
seed?: string
numInferenceSteps: number
guidanceScale: number
signal?: AbortSignal
}): Promise<string[]>
/**
 * Generates images for the given parameters; each concrete provider supplies
 * its own implementation.
 *
 * @param params - Image generation options (see `GenerateImageParams`).
 * @returns A promise resolving to a list of strings (callers treat them as image URLs).
 */
abstract generateImage(params: GenerateImageParams): Promise<string[]>
abstract getEmbeddingDimensions(model: Model): Promise<number>
public getBaseURL(): string {

View File

@ -4,7 +4,7 @@ import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import { EVENT_NAMES } from '@renderer/services/EventService'
import { filterContextMessages } from '@renderer/services/MessagesService'
import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
import { Assistant, FileTypes, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
import { removeQuotes } from '@renderer/utils'
import { last, takeRight } from 'lodash'
import OpenAI, { AzureOpenAI } from 'openai'
@ -345,6 +345,7 @@ export default class OpenAIProvider extends BaseProvider {
}
public async generateImage({
model,
prompt,
negativePrompt,
imageSize,
@ -352,30 +353,23 @@ export default class OpenAIProvider extends BaseProvider {
seed,
numInferenceSteps,
guidanceScale,
signal
}: {
prompt: string
negativePrompt?: string
imageSize: string
batchSize: number
seed?: string
numInferenceSteps: number
guidanceScale: number
signal?: AbortSignal
}): Promise<string[]> {
signal,
promptEnhancement
}: GenerateImageParams): Promise<string[]> {
const response = (await this.sdk.request({
method: 'post',
path: '/images/generations',
signal,
body: {
model: 'stabilityai/stable-diffusion-3-5-large',
model,
prompt,
negative_prompt: negativePrompt,
image_size: imageSize,
batch_size: batchSize,
seed: seed ? parseInt(seed) : undefined,
num_inference_steps: numInferenceSteps,
guidance_scale: guidanceScale
guidance_scale: guidanceScale,
prompt_enhancement: promptEnhancement
}
})) as { data: Array<{ url: string }> }

View File

@ -112,6 +112,7 @@ export type Suggestion = {
export interface Painting {
id: string
model?: string
urls: string[]
files: FileType[]
prompt?: string
@ -121,7 +122,7 @@ export interface Painting {
seed?: string
steps?: number
guidanceScale?: number
model?: string
promptEnhancement?: boolean
}
export type MinAppType = {
@ -224,3 +225,16 @@ export type KnowledgeBaseParams = {
apiVersion?: string
baseURL: string
}
/**
 * Options accepted by the providers' `generateImage` methods.
 *
 * Field notes below describe how the paintings page fills these in and how
 * OpenAIProvider forwards them in the `/images/generations` request body.
 */
export type GenerateImageParams = {
// Id of the model to generate with; sent as `model`.
model: string
// Text prompt for the image; sent as `prompt`.
prompt: string
// Optional negative prompt; sent as `negative_prompt`.
negativePrompt?: string
// Image size string, sent as `image_size`; the paintings page falls back to '1024x1024'.
imageSize: string
// Sent as `batch_size` — presumably the number of images per request; confirm against the API.
batchSize: number
// Optional seed kept as a string; OpenAIProvider converts it with parseInt before sending.
seed?: string
// Sent as `num_inference_steps`; the paintings page falls back to 25.
numInferenceSteps: number
// Sent as `guidance_scale`; the paintings page falls back to 4.5.
guidanceScale: number
// Optional abort signal handed to the HTTP request so generation can be cancelled.
signal?: AbortSignal
// Optional flag sent as `prompt_enhancement`; the paintings page sends false when it is unset.
promptEnhancement?: boolean
}