Skip to content

Commit

Permalink
feat: automatic canvas naming after initial skill invocation
Browse files Browse the repository at this point in the history
  • Loading branch information
mrcfps committed Feb 14, 2025
1 parent 6bbe17c commit 42b0a00
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 78 deletions.
5 changes: 5 additions & 0 deletions apps/api/src/canvas/canvas.dto.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ export interface DeleteCanvasNodesJobData {
entities: Entity[];
}

export interface AutoNameCanvasJobData {
uid: string;
canvasId: string;
}

export function canvasPO2DTO(canvas: CanvasModel): Canvas {
return {
...pick(canvas, ['canvasId', 'title', 'shareCode']),
Expand Down
13 changes: 11 additions & 2 deletions apps/api/src/canvas/canvas.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ import { Module } from '@nestjs/common';
import { BullModule } from '@nestjs/bullmq';
import { CanvasController } from './canvas.controller';
import { CanvasService } from './canvas.service';
import { ClearCanvasEntityProcessor, SyncCanvasEntityProcessor } from './canvas.processor';
import {
ClearCanvasEntityProcessor,
SyncCanvasEntityProcessor,
AutoNameCanvasProcessor,
} from './canvas.processor';
import { CollabModule } from '@/collab/collab.module';
import { QUEUE_DELETE_KNOWLEDGE_ENTITY } from '@/utils/const';
import { CommonModule } from '@/common/common.module';
Expand All @@ -18,7 +22,12 @@ import { MiscModule } from '@/misc/misc.module';
}),
],
controllers: [CanvasController],
providers: [CanvasService, SyncCanvasEntityProcessor, ClearCanvasEntityProcessor],
providers: [
CanvasService,
SyncCanvasEntityProcessor,
ClearCanvasEntityProcessor,
AutoNameCanvasProcessor,
],
exports: [CanvasService],
})
export class CanvasModule {}
26 changes: 24 additions & 2 deletions apps/api/src/canvas/canvas.processor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,16 @@ import { Logger } from '@nestjs/common';
import { Job } from 'bullmq';

import { CanvasService } from './canvas.service';
import { QUEUE_CLEAR_CANVAS_ENTITY, QUEUE_SYNC_CANVAS_ENTITY } from '@/utils/const';
import { DeleteCanvasNodesJobData, SyncCanvasEntityJobData } from './canvas.dto';
import {
QUEUE_CLEAR_CANVAS_ENTITY,
QUEUE_SYNC_CANVAS_ENTITY,
QUEUE_AUTO_NAME_CANVAS,
} from '@/utils/const';
import {
DeleteCanvasNodesJobData,
SyncCanvasEntityJobData,
AutoNameCanvasJobData,
} from './canvas.dto';

@Processor(QUEUE_SYNC_CANVAS_ENTITY)
export class SyncCanvasEntityProcessor extends WorkerHost {
Expand Down Expand Up @@ -46,3 +54,17 @@ export class ClearCanvasEntityProcessor extends WorkerHost {
}
}
}

@Processor(QUEUE_AUTO_NAME_CANVAS)
export class AutoNameCanvasProcessor extends WorkerHost {
private logger = new Logger(AutoNameCanvasProcessor.name);

constructor(private canvasService: CanvasService) {
super();
}

async process(job: Job<AutoNameCanvasJobData>) {
this.logger.log(`Processing auto name canvas job ${job.id} for canvas ${job.data.canvasId}`);
await this.canvasService.autoNameCanvasFromQueue(job.data);
}
}
102 changes: 47 additions & 55 deletions apps/api/src/canvas/canvas.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import { ConfigService } from '@nestjs/config';
import { ChatOpenAI } from '@langchain/openai';
import { HumanMessage } from '@langchain/core/messages';
import { SystemMessage } from '@langchain/core/messages';
import { AutoNameCanvasJobData } from './canvas.dto';

@Injectable()
export class CanvasService {
Expand Down Expand Up @@ -341,65 +342,44 @@ export class CanvasService {
throw new CanvasNotFoundError();
}

// Get all entities associated with the canvas
const relations = await this.prisma.canvasEntityRelation.findMany({
where: { canvasId, deletedAt: null },
const results = await this.prisma.actionResult.findMany({
select: { title: true, input: true, version: true, resultId: true },
where: { targetId: canvasId, targetType: 'canvas' },
});

// Collect content from all entities
const contentPromises = relations.map(async (relation) => {
switch (relation.entityType) {
case 'resource': {
const resource = await this.prisma.resource.findUnique({
select: {
title: true,
contentPreview: true,
},
where: { resourceId: relation.entityId },
});
return resource
? `Title: ${resource?.title}\nContent Preview: ${resource?.contentPreview}`
: '';
}
case 'document': {
const document = await this.prisma.document.findUnique({
select: {
title: true,
contentPreview: true,
},
where: { docId: relation.entityId },
});
return document
? `Title: ${document?.title}\nContent Preview: ${document?.contentPreview}`
: '';
}
case 'skillResponse': {
const result = await this.prisma.actionResult.findFirst({
select: { title: true, input: true, version: true },
where: { resultId: relation.entityId },
});
if (!result) {
return '';
}
const steps = await this.prisma.actionStep.findMany({
select: {
content: true,
},
where: { resultId: relation.entityId, version: result.version },
orderBy: { order: 'asc' },
});
const input = JSON.parse(result?.input ?? '{}');
const question = input?.query ?? result?.title;
const canvasContent = await Promise.all(
results.map(async (result) => {
const { resultId, version, input, title } = result;
const steps = await this.prisma.actionStep.findMany({
where: { resultId, version },
});
const parsedInput = JSON.parse(input ?? '{}');
const question = parsedInput?.query ?? title;

return `Question: ${question}\nAnswer: ${steps.map((s) => s.content.slice(0, 100)).join('\n')}`;
}
default:
return '';
}
});
return `Question: ${question}\nAnswer: ${steps.map((s) => s.content.slice(0, 100)).join('\n')}`;
}),
);

// If no action results, try to get all entities associated with the canvas
if (canvasContent.length === 0) {
const relations = await this.prisma.canvasEntityRelation.findMany({
where: { canvasId, entityType: { in: ['resource', 'document'] }, deletedAt: null },
});
const documents = await this.prisma.document.findMany({
select: { title: true, contentPreview: true },
where: { docId: { in: relations.map((r) => r.entityId) } },
});
const resources = await this.prisma.resource.findMany({
select: { title: true, contentPreview: true },
where: { resourceId: { in: relations.map((r) => r.entityId) } },
});
canvasContent.push(
...documents.map((d) => `Title: ${d.title}\nContent Preview: ${d.contentPreview}`),
...resources.map((r) => `Title: ${r.title}\nContent Preview: ${r.contentPreview}`),
);
}

const contents = await Promise.all(contentPromises);
const combinedContent = contents.filter(Boolean).join('\n');
const combinedContent = canvasContent.filter(Boolean).join('\n\n');

if (!combinedContent) {
return { title: '' };
Expand Down Expand Up @@ -434,4 +414,16 @@ export class CanvasService {

return { title: newTitle };
}

async autoNameCanvasFromQueue(jobData: AutoNameCanvasJobData) {
const { uid, canvasId } = jobData;
const user = await this.prisma.user.findFirst({ where: { uid } });
if (!user) {
this.logger.warn(`user not found for uid ${uid} when auto naming canvas: ${canvasId}`);
return;
}

const result = await this.autoNameCanvas(user, { canvasId, directUpdate: true });
this.logger.log(`Auto named canvas ${canvasId} with title: ${result.title}`);
}
}
2 changes: 2 additions & 0 deletions apps/api/src/skill/skill.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
QUEUE_SKILL,
QUEUE_SKILL_TIMEOUT_CHECK,
QUEUE_SYNC_REQUEST_USAGE,
QUEUE_AUTO_NAME_CANVAS,
} from '@/utils';
import { LabelModule } from '@/label/label.module';
import { SkillProcessor, SkillTimeoutCheckProcessor } from '@/skill/skill.processor';
Expand All @@ -34,6 +35,7 @@ import { MiscModule } from '@/misc/misc.module';
BullModule.registerQueue({ name: QUEUE_SKILL_TIMEOUT_CHECK }),
BullModule.registerQueue({ name: QUEUE_SYNC_TOKEN_USAGE }),
BullModule.registerQueue({ name: QUEUE_SYNC_REQUEST_USAGE }),
BullModule.registerQueue({ name: QUEUE_AUTO_NAME_CANVAS }),
],
providers: [SkillService, SkillProcessor, SkillTimeoutCheckProcessor],
controllers: [SkillController],
Expand Down
17 changes: 17 additions & 0 deletions apps/api/src/skill/skill.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ import {
QUEUE_SYNC_TOKEN_USAGE,
QUEUE_SKILL_TIMEOUT_CHECK,
QUEUE_SYNC_REQUEST_USAGE,
QUEUE_AUTO_NAME_CANVAS,
} from '@/utils';
import { InvokeSkillJobData, SkillTimeoutCheckJobData } from './skill.dto';
import { KnowledgeService } from '@/knowledge/knowledge.service';
Expand Down Expand Up @@ -93,6 +94,7 @@ import { CollabContext } from '@/collab/collab.dto';
import { DirectConnection } from '@hocuspocus/server';
import { modelInfoPO2DTO } from '@/misc/misc.dto';
import { MiscService } from '@/misc/misc.service';
import { AutoNameCanvasJobData } from '@/canvas/canvas.dto';

function validateSkillTriggerCreateParam(param: SkillTriggerCreateParam) {
if (param.triggerType === 'simpleEvent') {
Expand Down Expand Up @@ -129,6 +131,8 @@ export class SkillService {
@InjectQueue(QUEUE_SYNC_TOKEN_USAGE) private usageReportQueue: Queue<SyncTokenUsageJobData>,
@InjectQueue(QUEUE_SYNC_REQUEST_USAGE)
private requestUsageQueue: Queue<SyncRequestUsageJobData>,
@InjectQueue(QUEUE_AUTO_NAME_CANVAS)
private autoNameCanvasQueue: Queue<AutoNameCanvasJobData>,
) {
this.skillEngine = new SkillEngine(this.logger, this.buildReflyService(), {
defaultModel: this.config.get('skill.defaultModel'),
Expand Down Expand Up @@ -1036,6 +1040,19 @@ export class SkillService {

writeSSEResponse(res, { event: 'end', resultId, version });

// Check if we need to auto-name the target canvas
if (data.target?.entityType === 'canvas' && !result.errors.length) {
const canvas = await this.prisma.canvas.findFirst({
where: { canvasId: data.target.entityId, uid: user.uid },
});
if (canvas && !canvas.title) {
await this.autoNameCanvasQueue.add('autoNameCanvas', {
uid: user.uid,
canvasId: canvas.canvasId,
});
}
}

await this.requestUsageQueue.add('syncRequestUsage', {
uid: user.uid,
tier,
Expand Down
1 change: 1 addition & 0 deletions apps/api/src/utils/const.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export const QUEUE_SYNC_STORAGE_USAGE = 'syncStorageUsage';
export const QUEUE_SYNC_CANVAS_ENTITY = 'syncCanvasEntity';
export const QUEUE_CLEAR_CANVAS_ENTITY = 'clearCanvasEntity';
export const QUEUE_DELETE_KNOWLEDGE_ENTITY = 'deleteKnowledgeEntity';
export const QUEUE_AUTO_NAME_CANVAS = 'autoNameCanvas';

export const QUEUE_SEND_VERIFICATION_EMAIL = 'sendVerificationEmail';
export const QUEUE_CHECK_CANCELED_SUBSCRIPTIONS = 'checkCanceledSubscriptions';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { useSiderStoreShallow } from '@refly-packages/ai-workspace-common/stores
import { useTranslation } from 'react-i18next';
import { LOCALE } from '@refly/common-types';
import { useDebounce } from 'use-debounce';

import { useDebouncedCallback } from 'use-debounce';
import { MdOutlineImage, MdOutlineAspectRatio } from 'react-icons/md';
import { AiOutlineMenuUnfold } from 'react-icons/ai';
import { BiErrorCircle } from 'react-icons/bi';
Expand All @@ -23,6 +23,7 @@ import { CanvasRename } from './canvas-rename';
import { HoverCard } from '@refly-packages/ai-workspace-common/components/hover-card';
import { CanvasActionDropdown } from '@refly-packages/ai-workspace-common/components/workspace/canvas-list-modal/canvasActionDropdown';
import { useHoverCard } from '@refly-packages/ai-workspace-common/hooks/use-hover-card';
import { useHandleSiderData } from '@refly-packages/ai-workspace-common/hooks/use-handle-sider-data';

interface TopToolbarProps {
canvasId: string;
Expand All @@ -49,21 +50,34 @@ const CanvasTitle = memo(
updateCanvasTitle: state.updateCanvasTitle,
}));

const handleEditClick = () => {
const handleEditClick = useCallback(() => {
setIsModalOpen(true);
};

const handleModalOk = (newTitle: string) => {
if (newTitle?.trim()) {
syncTitleToYDoc(newTitle);
updateCanvasTitle(canvasId, newTitle);
setIsModalOpen(false);
}
};
}, []);

const handleModalOk = useCallback(
(newTitle: string) => {
if (newTitle?.trim()) {
syncTitleToYDoc(newTitle);
updateCanvasTitle(canvasId, newTitle);
setIsModalOpen(false);
}
},
[canvasId, syncTitleToYDoc, updateCanvasTitle],
);

const handleModalCancel = () => {
const handleModalCancel = useCallback(() => {
setIsModalOpen(false);
};
}, []);

const { getCanvasList } = useHandleSiderData();
const debouncedRefetchCanvasList = useDebouncedCallback(async () => {
await getCanvasList();
}, 500);

// Refetch canvas list when canvas title changes
useEffect(() => {
debouncedRefetchCanvasList();
}, [canvasTitle]);

return (
<>
Expand Down
12 changes: 6 additions & 6 deletions packages/ai-workspace-common/src/hooks/use-handle-sider-data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ export const useHandleSiderData = (initData?: boolean) => {

const [isLoadingCanvas, setIsLoadingCanvas] = useState(false);

const getCanvasList = async () => {
setIsLoadingCanvas(true);
const getCanvasList = async (setLoading?: boolean) => {
setLoading && setIsLoadingCanvas(true);
const { data: res, error } = await getClient().listCanvases({
query: { page: 1, pageSize: DATA_NUM },
});
setIsLoadingCanvas(false);
setLoading && setIsLoadingCanvas(false);
if (error) {
console.error('getCanvasList error', error);
return [];
Expand Down Expand Up @@ -85,13 +85,13 @@ export const useHandleSiderData = (initData?: boolean) => {
updateLibraryList(libraryList);
};

const loadSiderData = async () => {
getCanvasList();
const loadSiderData = async (setLoading?: boolean) => {
getCanvasList(setLoading);
};

useEffect(() => {
if (initData) {
loadSiderData();
loadSiderData(true);
}
}, []);

Expand Down

0 comments on commit 42b0a00

Please sign in to comment.