feat: paintings add prompt enhancement params
This commit is contained in:
parent
6384525e20
commit
038aa2d5cc
@ -961,6 +961,11 @@ export const TEXT_TO_IMAGES_MODELS = [
|
||||
}
|
||||
]
|
||||
|
||||
export const TEXT_TO_IMAGES_MODELS_SUPPORT_IMAGE_ENHANCEMENT = [
|
||||
'stabilityai/stable-diffusion-2-1',
|
||||
'stabilityai/stable-diffusion-xl-base-1.0'
|
||||
]
|
||||
|
||||
export function isTextToImageModel(model: Model): boolean {
|
||||
return TEXT_TO_IMAGE_REGEX.test(model.id)
|
||||
}
|
||||
|
||||
@ -14,6 +14,7 @@ export function usePaintings() {
|
||||
paintings,
|
||||
addPainting: () => {
|
||||
const newPainting: Painting = {
|
||||
model: TEXT_TO_IMAGES_MODELS[0].id,
|
||||
id: uuid(),
|
||||
urls: [],
|
||||
files: [],
|
||||
@ -24,7 +25,7 @@ export function usePaintings() {
|
||||
seed: generateRandomSeed(),
|
||||
steps: 25,
|
||||
guidanceScale: 4.5,
|
||||
model: TEXT_TO_IMAGES_MODELS[0].id
|
||||
promptEnhancement: true
|
||||
}
|
||||
dispatch(addPainting(newPainting))
|
||||
return newPainting
|
||||
|
||||
@ -279,7 +279,9 @@
|
||||
"regenerate.confirm": "This will replace your existing generated images. Do you want to continue?",
|
||||
"seed": "Seed",
|
||||
"seed_tip": "The same seed and prompt can produce similar images",
|
||||
"title": "Images"
|
||||
"title": "Images",
|
||||
"prompt_enhancement": "Prompt Enhancement",
|
||||
"prompt_enhancement_tip": "Rewrite prompts into detailed, model-friendly versions when switched on"
|
||||
},
|
||||
"provider": {
|
||||
"aihubmix": "AiHubMix",
|
||||
|
||||
@ -277,7 +277,9 @@
|
||||
"regenerate.confirm": "これにより、既存の生成画像が置き換えられます。続行しますか?",
|
||||
"seed": "シード",
|
||||
"seed_tip": "同じシードとプロンプトで似た画像を生成できます",
|
||||
"title": "画像"
|
||||
"title": "画像",
|
||||
"prompt_enhancement": "プロンプト強化",
|
||||
"prompt_enhancement_tip": "オンにすると、プロンプトを詳細でモデルに適したバージョンに書き直します"
|
||||
},
|
||||
"provider": {
|
||||
"aihubmix": "AiHubMix",
|
||||
|
||||
@ -279,7 +279,9 @@
|
||||
"regenerate.confirm": "Это заменит ваши существующие сгенерированные изображения. Хотите продолжить?",
|
||||
"seed": "Ключ генерации",
|
||||
"seed_tip": "Одинаковый ключ генерации и промпт могут производить похожие изображения",
|
||||
"title": "Изображения"
|
||||
"title": "Изображения",
|
||||
"prompt_enhancement": "Улучшение промпта",
|
||||
"prompt_enhancement_tip": "При включении переписывает промпт в более детальную, модель-ориентированную версию"
|
||||
},
|
||||
"provider": {
|
||||
"aihubmix": "AiHubMix",
|
||||
|
||||
@ -280,7 +280,9 @@
|
||||
"regenerate.confirm": "这将覆盖已生成的图片,是否继续?",
|
||||
"seed": "随机种子",
|
||||
"seed_tip": "相同的种子和提示词可以生成相似的图片",
|
||||
"title": "图片"
|
||||
"title": "图片",
|
||||
"prompt_enhancement": "提示词增强",
|
||||
"prompt_enhancement_tip": "开启后将提示重写为详细的、适合模型的版本"
|
||||
},
|
||||
"provider": {
|
||||
"aihubmix": "AiHubMix",
|
||||
|
||||
@ -279,7 +279,9 @@
|
||||
"regenerate.confirm": "這將覆蓋已生成的圖片,是否繼續?",
|
||||
"seed": "隨機種子",
|
||||
"seed_tip": "相同的種子和提示詞可以生成相似的圖片",
|
||||
"title": "繪圖"
|
||||
"title": "繪圖",
|
||||
"prompt_enhancement": "提示詞增強",
|
||||
"prompt_enhancement_tip": "開啟後將提示重寫為詳細的、適合模型的版本"
|
||||
},
|
||||
"provider": {
|
||||
"aihubmix": "AiHubMix",
|
||||
|
||||
@ -6,7 +6,7 @@ import ImageSize3_4 from '@renderer/assets/images/paintings/image-size-3-4.svg'
|
||||
import ImageSize9_16 from '@renderer/assets/images/paintings/image-size-9-16.svg'
|
||||
import ImageSize16_9 from '@renderer/assets/images/paintings/image-size-16-9.svg'
|
||||
import { Navbar, NavbarCenter, NavbarRight } from '@renderer/components/app/Navbar'
|
||||
import { VStack } from '@renderer/components/Layout'
|
||||
import { HStack, VStack } from '@renderer/components/Layout'
|
||||
import Scrollbar from '@renderer/components/Scrollbar'
|
||||
import TranslateButton from '@renderer/components/TranslateButton'
|
||||
import { isMac } from '@renderer/config/constant'
|
||||
@ -25,7 +25,7 @@ import { DEFAULT_PAINTING } from '@renderer/store/paintings'
|
||||
import { setGenerating } from '@renderer/store/runtime'
|
||||
import { FileType, Painting } from '@renderer/types'
|
||||
import { getErrorMessage } from '@renderer/utils'
|
||||
import { Button, Input, InputNumber, Radio, Select, Slider, Tooltip } from 'antd'
|
||||
import { Button, Input, InputNumber, Radio, Select, Slider, Switch, Tooltip } from 'antd'
|
||||
import TextArea from 'antd/es/input/TextArea'
|
||||
import { FC, useEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
@ -149,8 +149,13 @@ const PaintingsPage: FC = () => {
|
||||
dispatch(setGenerating(true))
|
||||
const AI = new AiProvider(provider)
|
||||
|
||||
if (!painting.model) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const urls = await AI.generateImage({
|
||||
model: painting.model,
|
||||
prompt,
|
||||
negativePrompt: painting.negativePrompt || '',
|
||||
imageSize: painting.imageSize || '1024x1024',
|
||||
@ -158,7 +163,8 @@ const PaintingsPage: FC = () => {
|
||||
seed: painting.seed || undefined,
|
||||
numInferenceSteps: painting.steps || 25,
|
||||
guidanceScale: painting.guidanceScale || 4.5,
|
||||
signal: controller.signal
|
||||
signal: controller.signal,
|
||||
promptEnhancement: painting.promptEnhancement || false
|
||||
})
|
||||
|
||||
if (urls.length > 0) {
|
||||
@ -360,13 +366,15 @@ const PaintingsPage: FC = () => {
|
||||
<InfoIcon />
|
||||
</Tooltip>
|
||||
</SettingTitle>
|
||||
<SliderContainer>
|
||||
<Slider min={1} max={50} value={painting.steps} onChange={(v) => updatePaintingState({ steps: v })} />
|
||||
<InputNumber
|
||||
<StyledInputNumber
|
||||
min={1}
|
||||
max={50}
|
||||
value={painting.steps}
|
||||
onChange={(v) => updatePaintingState({ steps: v || 25 })}
|
||||
onChange={(v) => updatePaintingState({ steps: (v as number) || 25 })}
|
||||
/>
|
||||
</SliderContainer>
|
||||
|
||||
<SettingTitle style={{ marginBottom: 5, marginTop: 15 }}>
|
||||
{t('paintings.guidance_scale')}
|
||||
@ -374,6 +382,7 @@ const PaintingsPage: FC = () => {
|
||||
<InfoIcon />
|
||||
</Tooltip>
|
||||
</SettingTitle>
|
||||
<SliderContainer>
|
||||
<Slider
|
||||
min={1}
|
||||
max={20}
|
||||
@ -381,14 +390,14 @@ const PaintingsPage: FC = () => {
|
||||
value={painting.guidanceScale}
|
||||
onChange={(v) => updatePaintingState({ guidanceScale: v })}
|
||||
/>
|
||||
<InputNumber
|
||||
<StyledInputNumber
|
||||
min={1}
|
||||
max={20}
|
||||
step={0.1}
|
||||
value={painting.guidanceScale}
|
||||
onChange={(v) => updatePaintingState({ guidanceScale: v || 4.5 })}
|
||||
onChange={(v) => updatePaintingState({ guidanceScale: (v as number) || 4.5 })}
|
||||
/>
|
||||
|
||||
</SliderContainer>
|
||||
<SettingTitle style={{ marginBottom: 5, marginTop: 15 }}>
|
||||
{t('paintings.negative_prompt')}
|
||||
<Tooltip title={t('paintings.negative_prompt_tip')}>
|
||||
@ -400,6 +409,18 @@ const PaintingsPage: FC = () => {
|
||||
onChange={(e) => updatePaintingState({ negativePrompt: e.target.value })}
|
||||
rows={4}
|
||||
/>
|
||||
<SettingTitle style={{ marginBottom: 5, marginTop: 15 }}>
|
||||
{t('paintings.prompt_enhancement')}
|
||||
<Tooltip title={t('paintings.prompt_enhancement_tip')}>
|
||||
<InfoIcon />
|
||||
</Tooltip>
|
||||
</SettingTitle>
|
||||
<HStack>
|
||||
<Switch
|
||||
checked={painting.promptEnhancement}
|
||||
onChange={(checked) => updatePaintingState({ promptEnhancement: checked })}
|
||||
/>
|
||||
</HStack>
|
||||
</LeftContainer>
|
||||
<MainContainer>
|
||||
<Artboard
|
||||
@ -547,4 +568,18 @@ const InfoIcon = styled(QuestionCircleOutlined)`
|
||||
}
|
||||
`
|
||||
|
||||
const SliderContainer = styled.div`
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 16px;
|
||||
|
||||
.ant-slider {
|
||||
flex: 1;
|
||||
}
|
||||
`
|
||||
|
||||
const StyledInputNumber = styled(InputNumber)`
|
||||
width: 70px;
|
||||
`
|
||||
|
||||
export default PaintingsPage
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import BaseProvider from '@renderer/providers/BaseProvider'
|
||||
import ProviderFactory from '@renderer/providers/ProviderFactory'
|
||||
import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||
import { Assistant, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import { CompletionsParams } from '.'
|
||||
@ -48,16 +48,7 @@ export default class AiProvider {
|
||||
return this.sdk.getApiKey()
|
||||
}
|
||||
|
||||
public async generateImage(params: {
|
||||
prompt: string
|
||||
negativePrompt: string
|
||||
imageSize: string
|
||||
batchSize: number
|
||||
seed?: string
|
||||
numInferenceSteps: number
|
||||
guidanceScale: number
|
||||
signal?: AbortSignal
|
||||
}): Promise<string[]> {
|
||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
return this.sdk.generateImage(params)
|
||||
}
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import { REFERENCE_PROMPT } from '@renderer/config/prompts'
|
||||
import { getOllamaKeepAliveTime } from '@renderer/hooks/useOllama'
|
||||
import { getKnowledgeReferences } from '@renderer/services/KnowledgeService'
|
||||
import store from '@renderer/store'
|
||||
import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||
import { Assistant, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||
import { delay, isJSON } from '@renderer/utils'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
@ -26,16 +26,7 @@ export default abstract class BaseProvider {
|
||||
abstract generateText({ prompt, content }: { prompt: string; content: string }): Promise<string>
|
||||
abstract check(): Promise<{ valid: boolean; error: Error | null }>
|
||||
abstract models(): Promise<OpenAI.Models.Model[]>
|
||||
abstract generateImage(_params: {
|
||||
prompt: string
|
||||
negativePrompt: string
|
||||
imageSize: string
|
||||
batchSize: number
|
||||
seed?: string
|
||||
numInferenceSteps: number
|
||||
guidanceScale: number
|
||||
signal?: AbortSignal
|
||||
}): Promise<string[]>
|
||||
abstract generateImage(params: GenerateImageParams): Promise<string[]>
|
||||
abstract getEmbeddingDimensions(model: Model): Promise<number>
|
||||
|
||||
public getBaseURL(): string {
|
||||
|
||||
@ -4,7 +4,7 @@ import i18n from '@renderer/i18n'
|
||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||
import { EVENT_NAMES } from '@renderer/services/EventService'
|
||||
import { filterContextMessages } from '@renderer/services/MessagesService'
|
||||
import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||
import { Assistant, FileTypes, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||
import { removeQuotes } from '@renderer/utils'
|
||||
import { last, takeRight } from 'lodash'
|
||||
import OpenAI, { AzureOpenAI } from 'openai'
|
||||
@ -345,6 +345,7 @@ export default class OpenAIProvider extends BaseProvider {
|
||||
}
|
||||
|
||||
public async generateImage({
|
||||
model,
|
||||
prompt,
|
||||
negativePrompt,
|
||||
imageSize,
|
||||
@ -352,30 +353,23 @@ export default class OpenAIProvider extends BaseProvider {
|
||||
seed,
|
||||
numInferenceSteps,
|
||||
guidanceScale,
|
||||
signal
|
||||
}: {
|
||||
prompt: string
|
||||
negativePrompt?: string
|
||||
imageSize: string
|
||||
batchSize: number
|
||||
seed?: string
|
||||
numInferenceSteps: number
|
||||
guidanceScale: number
|
||||
signal?: AbortSignal
|
||||
}): Promise<string[]> {
|
||||
signal,
|
||||
promptEnhancement
|
||||
}: GenerateImageParams): Promise<string[]> {
|
||||
const response = (await this.sdk.request({
|
||||
method: 'post',
|
||||
path: '/images/generations',
|
||||
signal,
|
||||
body: {
|
||||
model: 'stabilityai/stable-diffusion-3-5-large',
|
||||
model,
|
||||
prompt,
|
||||
negative_prompt: negativePrompt,
|
||||
image_size: imageSize,
|
||||
batch_size: batchSize,
|
||||
seed: seed ? parseInt(seed) : undefined,
|
||||
num_inference_steps: numInferenceSteps,
|
||||
guidance_scale: guidanceScale
|
||||
guidance_scale: guidanceScale,
|
||||
prompt_enhancement: promptEnhancement
|
||||
}
|
||||
})) as { data: Array<{ url: string }> }
|
||||
|
||||
|
||||
@ -112,6 +112,7 @@ export type Suggestion = {
|
||||
|
||||
export interface Painting {
|
||||
id: string
|
||||
model?: string
|
||||
urls: string[]
|
||||
files: FileType[]
|
||||
prompt?: string
|
||||
@ -121,7 +122,7 @@ export interface Painting {
|
||||
seed?: string
|
||||
steps?: number
|
||||
guidanceScale?: number
|
||||
model?: string
|
||||
promptEnhancement?: boolean
|
||||
}
|
||||
|
||||
export type MinAppType = {
|
||||
@ -224,3 +225,16 @@ export type KnowledgeBaseParams = {
|
||||
apiVersion?: string
|
||||
baseURL: string
|
||||
}
|
||||
|
||||
export type GenerateImageParams = {
|
||||
model: string
|
||||
prompt: string
|
||||
negativePrompt?: string
|
||||
imageSize: string
|
||||
batchSize: number
|
||||
seed?: string
|
||||
numInferenceSteps: number
|
||||
guidanceScale: number
|
||||
signal?: AbortSignal
|
||||
promptEnhancement?: boolean
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user