/**
 * Builds a knowPro search filter from the parsed command arguments.
 *
 * - `ktype` becomes the filter's knowledge type.
 * - `maxToDisplay` is a display option and is never part of the filter.
 * - Every other named argument whose value is not a function is treated as a
 *   property name/value pair that matches must satisfy.
 *
 * `propertiesToMatch` is only set when at least one such pair remains.
 */
function filterFromArgs(namedArgs: NamedArgs) {
    const properties: Record<string, string> = {};
    let propertyCount = 0;
    for (const [name, value] of Object.entries(namedArgs)) {
        const isOption = name === "maxToDisplay" || name === "ktype";
        // Function-valued entries are helpers on the args object, not user input
        if (isOption || typeof value === "function") {
            continue;
        }
        properties[name] = value;
        propertyCount += 1;
    }

    const filter: kp.SearchFilter = { type: namedArgs.ktype };
    if (propertyCount > 0) {
        filter.propertiesToMatch = properties;
    }
    return filter;
}
/**
 * Records which query terms produced matches, and which related terms
 * matched on behalf of which primary terms.
 */
export class TermMatchAccumulator {
    constructor(
        // Text of every primary term that produced a match
        public termMatches: Set<string> = new Set<string>(),
        // Related terms work 'on behalf' of a primary term.
        // Key: related term text. Value: the primary terms it matched for.
        public relatedTermMatchedFor: Map<string, Set<string>> = new Map<
            string,
            Set<string>
        >(),
    ) {}

    /**
     * Record a match for primaryTerm. When the match was actually produced by
     * relatedTerm, also record that relatedTerm matched on behalf of primaryTerm.
     */
    public add(primaryTerm: Term, relatedTerm?: Term) {
        this.termMatches.add(primaryTerm.text);
        if (relatedTerm === undefined) {
            return;
        }
        const matchedFor = this.relatedTermMatchedFor.get(relatedTerm.text);
        if (matchedFor !== undefined) {
            matchedFor.add(primaryTerm.text);
        } else {
            this.relatedTermMatchedFor.set(
                relatedTerm.text,
                new Set<string>([primaryTerm.text]),
            );
        }
    }

    /**
     * True if text matched as a primary term or, when includeRelated is set,
     * as a related term for any primary term.
     */
    public has(text: string, includeRelated: boolean = true): boolean {
        return (
            this.termMatches.has(text) ||
            (includeRelated && this.relatedTermMatchedFor.has(text))
        );
    }

    /**
     * True if relatedTerm was recorded as matching on behalf of primaryTerm.
     */
    public hasRelatedMatch(primaryTerm: string, relatedTerm: string): boolean {
        const matchedFor = this.relatedTermMatchedFor.get(relatedTerm);
        return matchedFor !== undefined && matchedFor.has(primaryTerm);
    }
}
/**
 * Predicate that tests a semantic ref against a set of property
 * name/value pairs.
 *
 * - matchAll === true: every pair must match.
 * - matchAll === false: at least one pair must match.
 *
 * With no pairs at all there is nothing to constrain, so the predicate passes.
 */
export class PropertyMatchPredicate implements IQuerySemanticRefPredicate {
    constructor(
        public nameValues: Record<string, string>,
        public matchAll: boolean = true,
    ) {}

    public eval(
        context: QueryEvalContext,
        termMatches: TermMatchAccumulator,
        semanticRef: SemanticRef,
    ): boolean {
        const names = Object.keys(this.nameValues);
        if (names.length === 0) {
            return true;
        }
        const isMatch = (name: string): boolean =>
            matchSemanticRefProperty(
                termMatches,
                semanticRef,
                name,
                this.nameValues[name],
            );
        // Previously the "any" mode returned true even when nothing matched
        return this.matchAll ? names.every(isMatch) : names.some(isMatch);
    }
}

/**
 * Dispatches a property test to the matcher for the semantic ref's
 * knowledge type. Unrecognized knowledge types never match.
 */
export function matchSemanticRefProperty(
    termMatches: TermMatchAccumulator,
    semanticRef: SemanticRef,
    propertyName: string,
    value: string,
): boolean {
    switch (semanticRef.knowledgeType) {
        case "entity":
            return matchEntityProperty(
                termMatches,
                semanticRef.knowledge as knowLib.conversation.ConcreteEntity,
                propertyName,
                value,
            );
        case "action":
            return matchActionProperty(
                termMatches,
                semanticRef.knowledge as knowLib.conversation.Action,
                propertyName,
                value,
            );
        case "topic":
            return matchTopicProperty(
                termMatches,
                semanticRef.knowledge as ITopic,
                propertyName,
                value,
            );
        case "tag":
            return matchTagProperty(
                termMatches,
                semanticRef.knowledge as ITag,
                propertyName,
                value,
            );
        default:
            return false;
    }
}

/**
 * Tests an entity property:
 * - "name" compares against entity.name
 * - "type" compares against any of entity.type
 * - any other name is treated as a facet name; the facet's value must match too
 */
export function matchEntityProperty(
    termMatches: TermMatchAccumulator,
    entity: knowLib.conversation.ConcreteEntity,
    propertyName: string,
    value: string,
): boolean {
    if (propertyName === "name") {
        return matchText(termMatches, value, entity.name);
    }
    if (propertyName === "type") {
        return matchTextOneOf(termMatches, value, entity.type);
    }
    if (entity.facets === undefined) {
        return false;
    }
    // Try facets
    for (const facet of entity.facets) {
        // NOTE(review): the guard assumes Facet.value can be undefined, as the
        // indexing code checks for before stringifying - confirm against the Facet type.
        if (
            matchText(termMatches, propertyName, facet.name) &&
            facet.value !== undefined &&
            matchText(
                termMatches,
                value,
                knowLib.conversation.knowledgeValueToString(facet.value),
            )
        ) {
            return true;
        }
    }
    return false;
}

/**
 * Property names understood by matchActionProperty.
 * Other strings are accepted but never match.
 */
export type ActionPropertyName =
    | "verb"
    | "subject"
    | "object"
    | "indirectObject"
    | string;

/**
 * Tests an action property: "verb" (any of action.verbs), "subject",
 * "object" or "indirectObject". Any other property name does not match.
 */
export function matchActionProperty(
    termMatches: TermMatchAccumulator,
    action: knowLib.conversation.Action,
    propertyName: ActionPropertyName,
    value: string,
): boolean {
    switch (propertyName) {
        case "verb":
            return matchTextOneOf(termMatches, value, action.verbs);
        case "subject":
            return matchText(termMatches, value, action.subjectEntityName);
        case "object":
            return matchText(termMatches, value, action.objectEntityName);
        case "indirectObject":
            return matchText(
                termMatches,
                value,
                action.indirectObjectEntityName,
            );
        default:
            return false;
    }
}

/**
 * Tests a topic: only the "topic" property is supported, compared against topic.text.
 */
export function matchTopicProperty(
    termMatches: TermMatchAccumulator,
    topic: ITopic,
    propertyName: string,
    value: string,
): boolean {
    return propertyName === "topic"
        ? matchText(termMatches, value, topic.text)
        : false;
}

/**
 * Tests a tag: only the "tag" property is supported, compared against tag.text.
 */
export function matchTagProperty(
    termMatches: TermMatchAccumulator,
    tag: ITag,
    propertyName: string,
    value: string,
): boolean {
    return propertyName === "tag"
        ? matchText(termMatches, value, tag.text)
        : false;
}
/**
 * Tests whether `actual` satisfies `expected`.
 *
 * A match is any of:
 * - expected is the wildcard "*"
 * - collections.stringEquals(expected, actual, false) holds
 *   (NOTE(review): the `false` flag is presumably "not case sensitive" - confirm;
 *   also confirm `collections` is imported in this module)
 * - `actual` was recorded as a related term that matched on behalf of `expected`
 *
 * An undefined `actual` never matches, not even the wildcard.
 */
function matchText(
    termMatches: TermMatchAccumulator,
    expected: string,
    actual: string | undefined,
): boolean {
    if (actual === undefined) {
        return false;
    }
    return (
        expected === "*" ||
        collections.stringEquals(expected, actual, false) ||
        termMatches.hasRelatedMatch(expected, actual)
    );
}

/**
 * Tests whether at least one entry of `actual` satisfies `expected`
 * (see matchText). An undefined or empty list never matches.
 */
function matchTextOneOf(
    termMatches: TermMatchAccumulator,
    expected: string,
    actual: string[] | undefined,
): boolean {
    if (actual === undefined) {
        return false;
    }
    // Previously fell through to `return true`, so every list "matched"
    return actual.some((text) => matchText(termMatches, expected, text));
}
/**
 * Compiles search terms and an optional filter into a query expression tree
 * for a conversation.
 *
 * Note: compile() normalizes its inputs in place - term text and filter
 * values are lower-cased, duplicate terms are removed, and filter
 * names/values are appended to the terms array.
 */
class SearchQueryBuilder {
    constructor(public conversation: IConversation) {}

    /**
     * Builds the full query: term matching (scoped and filtered), grouped by
     * knowledge type, keeping the top maxMatches per group.
     */
    public compile(
        terms: QueryTerm[],
        filter?: SearchFilter,
        maxMatches?: number,
    ) {
        this.prepareTerms(terms, filter);

        const select = this.compileSelect(terms, filter);
        return new q.SelectTopNKnowledgeGroupExpr(
            new q.GroupByKnowledgeTypeExpr(select),
            maxMatches,
        );
    }

    private compileSelect(terms: QueryTerm[], filter?: SearchFilter) {
        const queryTerms = new q.QueryTermsExpr(terms);
        const termsMatch = new q.TermsMatchExpr(
            this.conversation.relatedTermsIndex !== undefined
                ? new q.ResolveRelatedTermsExpr(queryTerms)
                : queryTerms,
        );
        // Always apply "tag match" scope... all text ranges that matched tags.. are in scope
        const scoped = new q.ScopeExpr(termsMatch, [
            new q.KnowledgeTypePredicate("tag"),
        ]);
        const predicates =
            filter !== undefined ? this.compileFilter(filter) : [];
        // Only add a where clause when there is something to test. A filter
        // object with nothing set must not introduce an empty predicate list.
        return predicates.length > 0
            ? new q.WhereSemanticRefExpr(scoped, predicates)
            : scoped;
    }

    /**
     * Translates the filter into where-clause predicates.
     * Returns an empty array when the filter constrains nothing.
     */
    private compileFilter(
        filter: SearchFilter,
    ): q.IQuerySemanticRefPredicate[] {
        const predicates: q.IQuerySemanticRefPredicate[] = [];
        if (filter.type) {
            predicates.push(new q.KnowledgeTypePredicate(filter.type));
        }
        if (filter.propertiesToMatch) {
            predicates.push(
                new q.PropertyMatchPredicate(filter.propertiesToMatch),
            );
        }
        return predicates;
    }

    /**
     * Normalizes terms in place:
     * - lower-cases term and related-term text
     * - removes terms whose text duplicates an earlier term (or is "*")
     * - appends filter property names (unless reserved) and values as terms,
     *   so that they take part in term matching; each text is added at most once
     */
    private prepareTerms(queryTerms: QueryTerm[], filter?: SearchFilter): void {
        const termText = new Set<string>();
        // "*" is the match-anything value; it is never searched for as a term
        termText.add("*");
        let i = 0;
        // Prepare terms and remove duplicates
        while (i < queryTerms.length) {
            const queryTerm = queryTerms[i];
            this.prepareTerm(queryTerm.term);
            if (termText.has(queryTerm.term.text)) {
                // Duplicate: remove and re-test the element that shifts into slot i
                queryTerms.splice(i, 1);
            } else {
                if (queryTerm.relatedTerms !== undefined) {
                    queryTerm.relatedTerms.forEach((t) => this.prepareTerm(t));
                }
                termText.add(queryTerm.term.text);
                ++i;
            }
        }
        this.prepareFilter(filter);
        // Ensure that all filter name values are also query terms
        if (filter !== undefined && filter.propertiesToMatch) {
            for (const key of Object.keys(filter.propertiesToMatch)) {
                if (!SearchQueryBuilder.reservedPropertyNames.has(key)) {
                    this.addUniqueTerm(queryTerms, termText, key);
                }
                this.addUniqueTerm(
                    queryTerms,
                    termText,
                    filter.propertiesToMatch[key],
                );
            }
        }
    }

    /**
     * Appends text as a new query term unless it was already seen, and
     * records it as seen so later filter entries cannot add it again.
     */
    private addUniqueTerm(
        queryTerms: QueryTerm[],
        seen: Set<string>,
        text: string,
    ): void {
        if (!seen.has(text)) {
            seen.add(text);
            queryTerms.push({ term: { text } });
        }
    }

    private prepareTerm(term: Term) {
        term.text = term.text.toLowerCase();
    }

    /**
     * Lower-cases filter property values in place. Property names are left
     * as-is because some reserved names are case sensitive ("indirectObject").
     */
    private prepareFilter(filter?: SearchFilter) {
        if (filter !== undefined && filter.propertiesToMatch) {
            for (const key of Object.keys(filter.propertiesToMatch)) {
                filter.propertiesToMatch[key] =
                    filter.propertiesToMatch[key].toLowerCase();
            }
        }
    }

    /**
     * Property names with built-in meaning. These name fields of the knowledge
     * being tested, so they are not added to the query as search terms.
     */
    static reservedPropertyNames = new Set([
        "name",
        "type",
        "topic",
        "tag",
        "verb",
        "subject",
        "object",
        "indirectObject",
    ]);
}