From ff01b21122e4835b599c3367afe9156d9f9c96db Mon Sep 17 00:00:00 2001 From: Terry Sutton Date: Fri, 23 Aug 2024 17:03:11 -0230 Subject: [PATCH] Add a replace code button to the ai assistant --- .../AiAssistantPanel/AiMessagePre.tsx | 27 +++++++++++++- .../interfaces/SQLEditor/SQLEditor.tsx | 37 +++++++++++++++---- .../interfaces/SQLEditor/SQLEditor.types.ts | 1 + .../interfaces/SQLEditor/SQLEditor.utils.ts | 4 ++ 4 files changed, 61 insertions(+), 8 deletions(-) diff --git a/apps/studio/components/interfaces/SQLEditor/AiAssistantPanel/AiMessagePre.tsx b/apps/studio/components/interfaces/SQLEditor/AiAssistantPanel/AiMessagePre.tsx index a08386377c240..790827a95813a 100644 --- a/apps/studio/components/interfaces/SQLEditor/AiAssistantPanel/AiMessagePre.tsx +++ b/apps/studio/components/interfaces/SQLEditor/AiAssistantPanel/AiMessagePre.tsx @@ -1,6 +1,6 @@ import { useTelemetryProps } from 'common' import { InsertCode, ReplaceCode } from 'icons' -import { Check, Copy } from 'lucide-react' +import { ArrowLeftRight, Check, Copy } from 'lucide-react' import { useRouter } from 'next/router' import { useEffect, useState } from 'react' import { format } from 'sql-formatter' @@ -113,6 +113,31 @@ export const AiMessagePre = ({ onDiff, children, className }: AiMessagePreProps) Replace code + + + + + + Overwrite code + + diff --git a/apps/studio/components/interfaces/SQLEditor/SQLEditor.tsx b/apps/studio/components/interfaces/SQLEditor/SQLEditor.tsx index 178c02b1dac60..6b8c9d87bba50 100644 --- a/apps/studio/components/interfaces/SQLEditor/SQLEditor.tsx +++ b/apps/studio/components/interfaces/SQLEditor/SQLEditor.tsx @@ -361,21 +361,34 @@ const SQLEditor = () => { setAiQueryCount((count) => count + 1) const existingValue = editorRef.current?.getValue() ?? '' - if (existingValue.length === 0) { - // if the editor is empty, just copy over the code + + const applyDirectly = (content: string) => { editorRef.current?.executeEdits('apply-ai-message', [ { - text: `${sqlAiDisclaimerComment}\n\n${sql}`, + text: content, range: editorModel.getFullModelRange(), }, ]) - } else { + } + + const setupDiffView = (original: string, modified: string) => { setSelectedMessage(id) - const currentSql = editorRef.current?.getValue() - const diff = { original: currentSql || '', modified: sql } - setSourceSqlDiff(diff) + setSourceSqlDiff({ original, modified }) setSelectedDiffType(diffType) } + + const sqlWithDisclaimer = `${sqlAiDisclaimerComment}\n\n${sql}` + + switch (true) { + case existingValue.length === 0: + applyDirectly(sqlWithDisclaimer) + break + case diffType === DiffType.Overwrite: + applyDirectly(sqlWithDisclaimer) + break + default: + setupDiffView(existingValue, sql) + } }, [setAiQueryCount] ) @@ -554,6 +567,12 @@ const SQLEditor = () => { return } + case DiffType.Overwrite: { + const transformedDiff = compareAsNewSnippet(sourceSqlDiff) + applyDiff(transformedDiff) + return + } + default: throw new Error(`Unknown diff type '${selectedDiffType}'`) } @@ -592,6 +611,10 @@ const SQLEditor = () => { return compareAsNewSnippet(sourceSqlDiff) } + case DiffType.Overwrite: { + return compareAsNewSnippet(sourceSqlDiff) + } + default: return { original: '', modified: '' } } diff --git a/apps/studio/components/interfaces/SQLEditor/SQLEditor.types.ts b/apps/studio/components/interfaces/SQLEditor/SQLEditor.types.ts index a01bbb7697c92..ca28fcbbf6b0d 100644 --- a/apps/studio/components/interfaces/SQLEditor/SQLEditor.types.ts +++ b/apps/studio/components/interfaces/SQLEditor/SQLEditor.types.ts @@ -31,4 +31,5 @@ export enum DiffType { Modification = 'modification', Addition = 'addition', NewSnippet = 'new-snippet', + Overwrite = 'overwrite', } diff --git a/apps/studio/components/interfaces/SQLEditor/SQLEditor.utils.ts b/apps/studio/components/interfaces/SQLEditor/SQLEditor.utils.ts index 4cce7028521f8..cd3dc0a8a5336 100644 --- a/apps/studio/components/interfaces/SQLEditor/SQLEditor.utils.ts +++ b/apps/studio/components/interfaces/SQLEditor/SQLEditor.utils.ts @@ -76,6 +76,8 @@ export function getDiffTypeButtonLabel(diffType: DiffType) { return 'Accept addition' case DiffType.NewSnippet: return 'Create new snippet' + case DiffType.Overwrite: + return 'Overwrite with snippet' default: throw new Error(`Unknown diff type '${diffType}'`) } @@ -89,6 +91,8 @@ export function getDiffTypeDropdownLabel(diffType: DiffType) { return 'Compare as addition' case DiffType.NewSnippet: return 'Compare as new snippet' + case DiffType.Overwrite: + return 'Overwrite as new snippet' default: throw new Error(`Unknown diff type '${diffType}'`) }