feat: paintings add prompt enhancement params

This commit is contained in:
kangfenmao 2025-01-02 14:51:52 +08:00
parent 6384525e20
commit 038aa2d5cc
12 changed files with 109 additions and 68 deletions

View File

@ -961,6 +961,11 @@ export const TEXT_TO_IMAGES_MODELS = [
} }
] ]
/**
 * Ids of text-to-image models flagged as supporting image enhancement.
 *
 * NOTE(review): nothing in the visible changes reads this list yet; presumably
 * it is meant to gate enhancement-related options per model. Confirm the
 * intended consumer before relying on it.
 */
export const TEXT_TO_IMAGES_MODELS_SUPPORT_IMAGE_ENHANCEMENT = [
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-xl-base-1.0'
]
export function isTextToImageModel(model: Model): boolean { export function isTextToImageModel(model: Model): boolean {
return TEXT_TO_IMAGE_REGEX.test(model.id) return TEXT_TO_IMAGE_REGEX.test(model.id)
} }

View File

@ -14,6 +14,7 @@ export function usePaintings() {
paintings, paintings,
addPainting: () => { addPainting: () => {
const newPainting: Painting = { const newPainting: Painting = {
model: TEXT_TO_IMAGES_MODELS[0].id,
id: uuid(), id: uuid(),
urls: [], urls: [],
files: [], files: [],
@ -24,7 +25,7 @@ export function usePaintings() {
seed: generateRandomSeed(), seed: generateRandomSeed(),
steps: 25, steps: 25,
guidanceScale: 4.5, guidanceScale: 4.5,
model: TEXT_TO_IMAGES_MODELS[0].id promptEnhancement: true
} }
dispatch(addPainting(newPainting)) dispatch(addPainting(newPainting))
return newPainting return newPainting

View File

@ -279,7 +279,9 @@
"regenerate.confirm": "This will replace your existing generated images. Do you want to continue?", "regenerate.confirm": "This will replace your existing generated images. Do you want to continue?",
"seed": "Seed", "seed": "Seed",
"seed_tip": "The same seed and prompt can produce similar images", "seed_tip": "The same seed and prompt can produce similar images",
"title": "Images" "title": "Images",
"prompt_enhancement": "Prompt Enhancement",
"prompt_enhancement_tip": "Rewrite prompts into detailed, model-friendly versions when switched on"
}, },
"provider": { "provider": {
"aihubmix": "AiHubMix", "aihubmix": "AiHubMix",

View File

@ -277,7 +277,9 @@
"regenerate.confirm": "これにより、既存の生成画像が置き換えられます。続行しますか?", "regenerate.confirm": "これにより、既存の生成画像が置き換えられます。続行しますか?",
"seed": "シード", "seed": "シード",
"seed_tip": "同じシードとプロンプトで似た画像を生成できます", "seed_tip": "同じシードとプロンプトで似た画像を生成できます",
"title": "画像" "title": "画像",
"prompt_enhancement": "プロンプト強化",
"prompt_enhancement_tip": "オンにすると、プロンプトを詳細でモデルに適したバージョンに書き直します"
}, },
"provider": { "provider": {
"aihubmix": "AiHubMix", "aihubmix": "AiHubMix",

View File

@ -279,7 +279,9 @@
"regenerate.confirm": "Это заменит ваши существующие сгенерированные изображения. Хотите продолжить?", "regenerate.confirm": "Это заменит ваши существующие сгенерированные изображения. Хотите продолжить?",
"seed": "Ключ генерации", "seed": "Ключ генерации",
"seed_tip": "Одинаковый ключ генерации и промпт могут производить похожие изображения", "seed_tip": "Одинаковый ключ генерации и промпт могут производить похожие изображения",
"title": "Изображения" "title": "Изображения",
"prompt_enhancement": "Улучшение промпта",
"prompt_enhancement_tip": "При включении переписывает промпт в более детальную, модель-ориентированную версию"
}, },
"provider": { "provider": {
"aihubmix": "AiHubMix", "aihubmix": "AiHubMix",

View File

@ -280,7 +280,9 @@
"regenerate.confirm": "这将覆盖已生成的图片,是否继续?", "regenerate.confirm": "这将覆盖已生成的图片,是否继续?",
"seed": "随机种子", "seed": "随机种子",
"seed_tip": "相同的种子和提示词可以生成相似的图片", "seed_tip": "相同的种子和提示词可以生成相似的图片",
"title": "图片" "title": "图片",
"prompt_enhancement": "提示词增强",
"prompt_enhancement_tip": "开启后将提示重写为详细的、适合模型的版本"
}, },
"provider": { "provider": {
"aihubmix": "AiHubMix", "aihubmix": "AiHubMix",

View File

@ -279,7 +279,9 @@
"regenerate.confirm": "這將覆蓋已生成的圖片,是否繼續?", "regenerate.confirm": "這將覆蓋已生成的圖片,是否繼續?",
"seed": "隨機種子", "seed": "隨機種子",
"seed_tip": "相同的種子和提示詞可以生成相似的圖片", "seed_tip": "相同的種子和提示詞可以生成相似的圖片",
"title": "繪圖" "title": "繪圖",
"prompt_enhancement": "提示詞增強",
"prompt_enhancement_tip": "開啟後將提示重寫為詳細的、適合模型的版本"
}, },
"provider": { "provider": {
"aihubmix": "AiHubMix", "aihubmix": "AiHubMix",

View File

@ -6,7 +6,7 @@ import ImageSize3_4 from '@renderer/assets/images/paintings/image-size-3-4.svg'
import ImageSize9_16 from '@renderer/assets/images/paintings/image-size-9-16.svg' import ImageSize9_16 from '@renderer/assets/images/paintings/image-size-9-16.svg'
import ImageSize16_9 from '@renderer/assets/images/paintings/image-size-16-9.svg' import ImageSize16_9 from '@renderer/assets/images/paintings/image-size-16-9.svg'
import { Navbar, NavbarCenter, NavbarRight } from '@renderer/components/app/Navbar' import { Navbar, NavbarCenter, NavbarRight } from '@renderer/components/app/Navbar'
import { VStack } from '@renderer/components/Layout' import { HStack, VStack } from '@renderer/components/Layout'
import Scrollbar from '@renderer/components/Scrollbar' import Scrollbar from '@renderer/components/Scrollbar'
import TranslateButton from '@renderer/components/TranslateButton' import TranslateButton from '@renderer/components/TranslateButton'
import { isMac } from '@renderer/config/constant' import { isMac } from '@renderer/config/constant'
@ -25,7 +25,7 @@ import { DEFAULT_PAINTING } from '@renderer/store/paintings'
import { setGenerating } from '@renderer/store/runtime' import { setGenerating } from '@renderer/store/runtime'
import { FileType, Painting } from '@renderer/types' import { FileType, Painting } from '@renderer/types'
import { getErrorMessage } from '@renderer/utils' import { getErrorMessage } from '@renderer/utils'
import { Button, Input, InputNumber, Radio, Select, Slider, Tooltip } from 'antd' import { Button, Input, InputNumber, Radio, Select, Slider, Switch, Tooltip } from 'antd'
import TextArea from 'antd/es/input/TextArea' import TextArea from 'antd/es/input/TextArea'
import { FC, useEffect, useRef, useState } from 'react' import { FC, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
@ -149,8 +149,13 @@ const PaintingsPage: FC = () => {
dispatch(setGenerating(true)) dispatch(setGenerating(true))
const AI = new AiProvider(provider) const AI = new AiProvider(provider)
if (!painting.model) {
return
}
try { try {
const urls = await AI.generateImage({ const urls = await AI.generateImage({
model: painting.model,
prompt, prompt,
negativePrompt: painting.negativePrompt || '', negativePrompt: painting.negativePrompt || '',
imageSize: painting.imageSize || '1024x1024', imageSize: painting.imageSize || '1024x1024',
@ -158,7 +163,8 @@ const PaintingsPage: FC = () => {
seed: painting.seed || undefined, seed: painting.seed || undefined,
numInferenceSteps: painting.steps || 25, numInferenceSteps: painting.steps || 25,
guidanceScale: painting.guidanceScale || 4.5, guidanceScale: painting.guidanceScale || 4.5,
signal: controller.signal signal: controller.signal,
promptEnhancement: painting.promptEnhancement || false
}) })
if (urls.length > 0) { if (urls.length > 0) {
@ -360,13 +366,15 @@ const PaintingsPage: FC = () => {
<InfoIcon /> <InfoIcon />
</Tooltip> </Tooltip>
</SettingTitle> </SettingTitle>
<Slider min={1} max={50} value={painting.steps} onChange={(v) => updatePaintingState({ steps: v })} /> <SliderContainer>
<InputNumber <Slider min={1} max={50} value={painting.steps} onChange={(v) => updatePaintingState({ steps: v })} />
min={1} <StyledInputNumber
max={50} min={1}
value={painting.steps} max={50}
onChange={(v) => updatePaintingState({ steps: v || 25 })} value={painting.steps}
/> onChange={(v) => updatePaintingState({ steps: (v as number) || 25 })}
/>
</SliderContainer>
<SettingTitle style={{ marginBottom: 5, marginTop: 15 }}> <SettingTitle style={{ marginBottom: 5, marginTop: 15 }}>
{t('paintings.guidance_scale')} {t('paintings.guidance_scale')}
@ -374,21 +382,22 @@ const PaintingsPage: FC = () => {
<InfoIcon /> <InfoIcon />
</Tooltip> </Tooltip>
</SettingTitle> </SettingTitle>
<Slider <SliderContainer>
min={1} <Slider
max={20} min={1}
step={0.1} max={20}
value={painting.guidanceScale} step={0.1}
onChange={(v) => updatePaintingState({ guidanceScale: v })} value={painting.guidanceScale}
/> onChange={(v) => updatePaintingState({ guidanceScale: v })}
<InputNumber />
min={1} <StyledInputNumber
max={20} min={1}
step={0.1} max={20}
value={painting.guidanceScale} step={0.1}
onChange={(v) => updatePaintingState({ guidanceScale: v || 4.5 })} value={painting.guidanceScale}
/> onChange={(v) => updatePaintingState({ guidanceScale: (v as number) || 4.5 })}
/>
</SliderContainer>
<SettingTitle style={{ marginBottom: 5, marginTop: 15 }}> <SettingTitle style={{ marginBottom: 5, marginTop: 15 }}>
{t('paintings.negative_prompt')} {t('paintings.negative_prompt')}
<Tooltip title={t('paintings.negative_prompt_tip')}> <Tooltip title={t('paintings.negative_prompt_tip')}>
@ -400,6 +409,18 @@ const PaintingsPage: FC = () => {
onChange={(e) => updatePaintingState({ negativePrompt: e.target.value })} onChange={(e) => updatePaintingState({ negativePrompt: e.target.value })}
rows={4} rows={4}
/> />
<SettingTitle style={{ marginBottom: 5, marginTop: 15 }}>
{t('paintings.prompt_enhancement')}
<Tooltip title={t('paintings.prompt_enhancement_tip')}>
<InfoIcon />
</Tooltip>
</SettingTitle>
<HStack>
<Switch
checked={painting.promptEnhancement}
onChange={(checked) => updatePaintingState({ promptEnhancement: checked })}
/>
</HStack>
</LeftContainer> </LeftContainer>
<MainContainer> <MainContainer>
<Artboard <Artboard
@ -547,4 +568,18 @@ const InfoIcon = styled(QuestionCircleOutlined)`
} }
` `
/**
 * Horizontal flex row that vertically centers its children with a 16px gap.
 * A nested Ant Design slider (`.ant-slider`) grows to take the remaining width,
 * so a slider and a fixed-width numeric input can share one line.
 */
const SliderContainer = styled.div`
display: flex;
align-items: center;
gap: 16px;
.ant-slider {
flex: 1;
}
`
/** Ant Design `InputNumber` pinned to a fixed 70px width. */
const StyledInputNumber = styled(InputNumber)`
width: 70px;
`
export default PaintingsPage export default PaintingsPage

View File

@ -1,6 +1,6 @@
import BaseProvider from '@renderer/providers/BaseProvider' import BaseProvider from '@renderer/providers/BaseProvider'
import ProviderFactory from '@renderer/providers/ProviderFactory' import ProviderFactory from '@renderer/providers/ProviderFactory'
import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types' import { Assistant, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
import OpenAI from 'openai' import OpenAI from 'openai'
import { CompletionsParams } from '.' import { CompletionsParams } from '.'
@ -48,16 +48,7 @@ export default class AiProvider {
return this.sdk.getApiKey() return this.sdk.getApiKey()
} }
public async generateImage(params: { public async generateImage(params: GenerateImageParams): Promise<string[]> {
prompt: string
negativePrompt: string
imageSize: string
batchSize: number
seed?: string
numInferenceSteps: number
guidanceScale: number
signal?: AbortSignal
}): Promise<string[]> {
return this.sdk.generateImage(params) return this.sdk.generateImage(params)
} }

View File

@ -2,7 +2,7 @@ import { REFERENCE_PROMPT } from '@renderer/config/prompts'
import { getOllamaKeepAliveTime } from '@renderer/hooks/useOllama' import { getOllamaKeepAliveTime } from '@renderer/hooks/useOllama'
import { getKnowledgeReferences } from '@renderer/services/KnowledgeService' import { getKnowledgeReferences } from '@renderer/services/KnowledgeService'
import store from '@renderer/store' import store from '@renderer/store'
import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types' import { Assistant, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
import { delay, isJSON } from '@renderer/utils' import { delay, isJSON } from '@renderer/utils'
import OpenAI from 'openai' import OpenAI from 'openai'
@ -26,16 +26,7 @@ export default abstract class BaseProvider {
abstract generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> abstract generateText({ prompt, content }: { prompt: string; content: string }): Promise<string>
abstract check(): Promise<{ valid: boolean; error: Error | null }> abstract check(): Promise<{ valid: boolean; error: Error | null }>
abstract models(): Promise<OpenAI.Models.Model[]> abstract models(): Promise<OpenAI.Models.Model[]>
abstract generateImage(_params: { abstract generateImage(params: GenerateImageParams): Promise<string[]>
prompt: string
negativePrompt: string
imageSize: string
batchSize: number
seed?: string
numInferenceSteps: number
guidanceScale: number
signal?: AbortSignal
}): Promise<string[]>
abstract getEmbeddingDimensions(model: Model): Promise<number> abstract getEmbeddingDimensions(model: Model): Promise<number>
public getBaseURL(): string { public getBaseURL(): string {

View File

@ -4,7 +4,7 @@ import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import { EVENT_NAMES } from '@renderer/services/EventService' import { EVENT_NAMES } from '@renderer/services/EventService'
import { filterContextMessages } from '@renderer/services/MessagesService' import { filterContextMessages } from '@renderer/services/MessagesService'
import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types' import { Assistant, FileTypes, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
import { removeQuotes } from '@renderer/utils' import { removeQuotes } from '@renderer/utils'
import { last, takeRight } from 'lodash' import { last, takeRight } from 'lodash'
import OpenAI, { AzureOpenAI } from 'openai' import OpenAI, { AzureOpenAI } from 'openai'
@ -345,6 +345,7 @@ export default class OpenAIProvider extends BaseProvider {
} }
public async generateImage({ public async generateImage({
model,
prompt, prompt,
negativePrompt, negativePrompt,
imageSize, imageSize,
@ -352,30 +353,23 @@ export default class OpenAIProvider extends BaseProvider {
seed, seed,
numInferenceSteps, numInferenceSteps,
guidanceScale, guidanceScale,
signal signal,
}: { promptEnhancement
prompt: string }: GenerateImageParams): Promise<string[]> {
negativePrompt?: string
imageSize: string
batchSize: number
seed?: string
numInferenceSteps: number
guidanceScale: number
signal?: AbortSignal
}): Promise<string[]> {
const response = (await this.sdk.request({ const response = (await this.sdk.request({
method: 'post', method: 'post',
path: '/images/generations', path: '/images/generations',
signal, signal,
body: { body: {
model: 'stabilityai/stable-diffusion-3-5-large', model,
prompt, prompt,
negative_prompt: negativePrompt, negative_prompt: negativePrompt,
image_size: imageSize, image_size: imageSize,
batch_size: batchSize, batch_size: batchSize,
seed: seed ? parseInt(seed) : undefined, seed: seed ? parseInt(seed) : undefined,
num_inference_steps: numInferenceSteps, num_inference_steps: numInferenceSteps,
guidance_scale: guidanceScale guidance_scale: guidanceScale,
prompt_enhancement: promptEnhancement
} }
})) as { data: Array<{ url: string }> } })) as { data: Array<{ url: string }> }

View File

@ -112,6 +112,7 @@ export type Suggestion = {
export interface Painting { export interface Painting {
id: string id: string
model?: string
urls: string[] urls: string[]
files: FileType[] files: FileType[]
prompt?: string prompt?: string
@ -121,7 +122,7 @@ export interface Painting {
seed?: string seed?: string
steps?: number steps?: number
guidanceScale?: number guidanceScale?: number
model?: string promptEnhancement?: boolean
} }
export type MinAppType = { export type MinAppType = {
@ -224,3 +225,16 @@ export type KnowledgeBaseParams = {
apiVersion?: string apiVersion?: string
baseURL: string baseURL: string
} }
/**
 * Arguments accepted by the providers' `generateImage` methods.
 *
 * The OpenAI-compatible provider forwards these fields in the body of a POST to
 * `/images/generations`, renaming them to snake_case (for example
 * `negativePrompt` becomes `negative_prompt`).
 */
export type GenerateImageParams = {
// Model id, sent as `model` in the request body.
model: string
// Text prompt, sent as `prompt`.
prompt: string
// Optional negative prompt, sent as `negative_prompt`.
negativePrompt?: string
// Sent as `image_size`; the paintings page falls back to '1024x1024'.
imageSize: string
// Sent as `batch_size` (presumably the number of images per request; confirm against the backend API).
batchSize: number
// Seed as a string; the OpenAI-compatible provider converts it with parseInt
// and sends `undefined` when it is empty or absent.
seed?: string
// Sent as `num_inference_steps`.
numInferenceSteps: number
// Sent as `guidance_scale`.
guidanceScale: number
// Abort signal handed to the SDK request so the caller can cancel generation.
signal?: AbortSignal
// Sent as `prompt_enhancement`; the paintings page passes `false` when unset.
promptEnhancement?: boolean
}