diff --git a/examples/orm/iterator.py b/examples/orm/iterator.py index f9fa496d5..e4a231b36 100644 --- a/examples/orm/iterator.py +++ b/examples/orm/iterator.py @@ -18,7 +18,7 @@ PICTURE = "picture" CONSISTENCY_LEVEL = "Eventually" LIMIT = 5 -NUM_ENTITIES = 1000 +NUM_ENTITIES = 10000 DIM = 8 CLEAR_EXIST = True @@ -26,23 +26,6 @@ log = logging.getLogger(__name__) log.setLevel(logging.INFO) # Set the log level to INFO -# Create a console handler and set its level to INFO -console_handler = logging.StreamHandler() -console_handler.setLevel(logging.INFO) - -# Create a formatter for the console output -formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - -# Add the formatter to the handler -console_handler.setFormatter(formatter) - -# Add the handler to the logger (this will apply globally) -log.addHandler(console_handler) - -# Now, configure the root logger to apply to the entire app (including your package) -logging.getLogger().setLevel(logging.INFO) # Set the root logger level to INFO -logging.getLogger().addHandler(console_handler) # Attach the handler to the root logger - def re_create_collection(prepare_new_data: bool): if prepare_new_data: @@ -51,8 +34,7 @@ def re_create_collection(prepare_new_data: bool): print(f"dropped existed collection{COLLECTION_NAME}") fields = [ - FieldSchema(name=USER_ID, dtype=DataType.VARCHAR, is_primary=True, - auto_id=False, max_length=MAX_LENGTH), + FieldSchema(name=USER_ID, dtype=DataType.INT64, is_primary=True, auto_id=False), FieldSchema(name=AGE, dtype=DataType.INT64), FieldSchema(name=DEPOSIT, dtype=DataType.DOUBLE), FieldSchema(name=PICTURE, dtype=DataType.FLOAT_VECTOR, dim=DIM) @@ -80,10 +62,9 @@ def random_pk(filter_set: set, lower_bound: int, upper_bound: int) -> str: def insert_data(collection): rng = np.random.default_rng(seed=19530) batch_count = 5 - filter_set: set = {} for i in range(batch_count): entities = [ - [random_pk(filter_set, 0, batch_count * NUM_ENTITIES) for _ in range(NUM_ENTITIES)], + [i for i in range(NUM_ENTITIES*i, NUM_ENTITIES*(i + 1))], [int(ni % 100) for ni in range(NUM_ENTITIES)], [float(ni) for ni in range(NUM_ENTITIES)], rng.random((NUM_ENTITIES, DIM)), @@ -117,7 +98,7 @@ def query_iterate_collection_no_offset(collection): query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE], offset=0, batch_size=5, consistency_level=CONSISTENCY_LEVEL, - reduce_stop_for_best="false", print_iterator_cursor=False, + reduce_stop_for_best="false", iterator_cp_file="/tmp/it_cp") no_best_ids: set = set({}) page_idx = 0 @@ -136,7 +117,7 @@ def query_iterate_collection_no_offset(collection): print("best---------------------------") query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE], offset=0, batch_size=5, consistency_level=CONSISTENCY_LEVEL, - reduce_stop_for_best="true", print_iterator_cursor=False, iterator_cp_file="/tmp/it_cp") + reduce_stop_for_best="true", iterator_cp_file="/tmp/it_cp") best_ids: set = set({}) page_idx = 0 @@ -160,7 +141,23 @@ def query_iterate_collection_no_offset(collection): def query_iterate_collection_with_offset(collection): expr = f"10 <= {AGE} <= 14" query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE], - offset=10, batch_size=50, consistency_level=CONSISTENCY_LEVEL, print_iterator_cursor=True) + offset=10, batch_size=50, consistency_level=CONSISTENCY_LEVEL) + page_idx = 0 + while True: + res = query_iterator.next() + if len(res) == 0: + print("query iteration finished, close") + query_iterator.close() + break + for i in range(len(res)): + print(res[i]) + page_idx += 1 + print(f"page{page_idx}-------------------------") + + +def query_iterate_collection_with_large_offset(collection): + query_iterator = collection.query_iterator(output_fields=[USER_ID, AGE], + offset=48000, batch_size=50, consistency_level=CONSISTENCY_LEVEL) page_idx = 0 while True: res = query_iterator.next() @@ -177,7 +174,7 @@ def query_iterate_collection_with_offset(collection): def query_iterate_collection_with_limit(collection): expr = f"10 <= {AGE} <= 44" query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE], - batch_size=80, limit=530, consistency_level=CONSISTENCY_LEVEL, print_iterator_cursor=True) + batch_size=80, limit=530, consistency_level=CONSISTENCY_LEVEL) page_idx = 0 while True: res = query_iterator.next() @@ -191,6 +188,8 @@ def query_iterate_collection_with_limit(collection): print(f"page{page_idx}-------------------------") + + def search_iterator_collection(collection): SEARCH_NQ = 1 DIM = 8 @@ -201,7 +200,7 @@ def search_iterator_collection(collection): "params": {"nprobe": 10, "radius": 1.0}, } search_iterator = collection.search_iterator(vectors_to_search, PICTURE, search_params, batch_size=500, - output_fields=[USER_ID], print_iterator_cursor=True) + output_fields=[USER_ID]) page_idx = 0 while True: res = search_iterator.next() @@ -225,7 +224,7 @@ def search_iterator_collection_with_limit(collection): "params": {"nprobe": 10, "radius": 1.0}, } search_iterator = collection.search_iterator(vectors_to_search, PICTURE, search_params, batch_size=200, limit=755, - output_fields=[USER_ID], print_iterator_cursor=True) + output_fields=[USER_ID]) page_idx = 0 while True: res = search_iterator.next() @@ -240,11 +239,12 @@ def search_iterator_collection_with_limit(collection): def main(): - prepare_new_data = True + prepare_new_data = False connections.connect("default", host=HOST, port=PORT) collection = re_create_collection(prepare_new_data) if prepare_new_data: collection = prepare_data(collection) + query_iterate_collection_with_large_offset(collection) query_iterate_collection_no_offset(collection) query_iterate_collection_with_offset(collection) query_iterate_collection_with_limit(collection) diff --git a/pymilvus/orm/constants.py b/pymilvus/orm/constants.py index 6862ab75f..d7b666501 100644 --- a/pymilvus/orm/constants.py +++ b/pymilvus/orm/constants.py @@ -48,7 +48,6 @@ REDUCE_STOP_FOR_BEST = "reduce_stop_for_best" ITERATOR_FIELD = "iterator" ITERATOR_SESSION_TS_FIELD = "iterator_session_ts" -PRINT_ITERATOR_CURSOR = "print_iterator_cursor" DEFAULT_MAX_L2_DISTANCE = 99999999.0 DEFAULT_MIN_IP_DISTANCE = -99999999.0 DEFAULT_MAX_HAMMING_DISTANCE = 99999999.0 diff --git a/pymilvus/orm/iterator.py b/pymilvus/orm/iterator.py index 15118523f..9764c1976 100644 --- a/pymilvus/orm/iterator.py +++ b/pymilvus/orm/iterator.py @@ -1,5 +1,6 @@ import datetime import logging +import time from copy import deepcopy from pathlib import Path from typing import Any, Callable, Dict, List, Optional, TypeVar, Union @@ -37,7 +38,6 @@ MILVUS_LIMIT, OFFSET, PARAMS, - PRINT_ITERATOR_CURSOR, RADIUS, RANGE_FILTER, REDUCE_STOP_FOR_BEST, @@ -52,8 +52,6 @@ QueryIterator = TypeVar("QueryIterator") SearchIterator = TypeVar("SearchIterator") -log = logging.getLogger(__name__) - def fall_back_to_latest_session_ts(): d = datetime.datetime.now() @@ -113,7 +111,6 @@ def __init__( self.__check_set_batch_size(batch_size) self._limit = limit self.__check_set_reduce_stop_for_best() - check_set_flag(self, "_print_iterator_cursor", self._kwargs, PRINT_ITERATOR_CURSOR) self._returned_count = 0 self.__setup__pk_prop() self.__set_up_expr(expr) @@ -131,18 +128,42 @@ def __seek_to_offset(self): if offset > 0: seek_params = self._kwargs.copy() seek_params[OFFSET] = 0 - seek_params[MILVUS_LIMIT] = offset - res = self._conn.query( - collection_name=self._collection_name, - expr=self._expr, - output_field=self._output_fields, - partition_name=self._partition_names, - timeout=self._timeout, - **seek_params, - ) - result_index = min(len(res), offset) - self.__update_cursor(res[:result_index]) + seek_params[ITERATOR_FIELD] = "False" + seek_params[REDUCE_STOP_FOR_BEST] = "False" + start_time = time.time() + + def seek_offset_by_batch(batch: int, expr: str) -> int: + seek_params[MILVUS_LIMIT] = batch + res = self._conn.query( + collection_name=self._collection_name, + expr=expr, + output_field=[], + partition_name=self._partition_names, + timeout=self._timeout, + **seek_params, + ) + self.__update_cursor(res) + return len(res) + + while offset > 0: + batch_size = min(MAX_BATCH_SIZE, offset) + next_expr = self.__setup_next_expr() + seeked_count = seek_offset_by_batch(batch_size, next_expr) + LOGGER.debug( + f"seeked offset, seek_expr:{next_expr} batch_size:{batch_size} seeked_count:{seeked_count}" + ) + if seeked_count == 0: + LOGGER.info( + "seek offset has drained all matched results for query iterator, break" + ) + break + offset -= seeked_count self._kwargs[OFFSET] = 0 + seek_offset_duration = time.time() - start_time + LOGGER.info( + f"Finish seek offset for query iterator, offset:{offset}, current_pk_cursor:{self._next_id}, " + f"duration:{seek_offset_duration}" + ) def __init_cp_file_handler(self) -> bool: mode = "w" @@ -170,14 +191,14 @@ def __save_pk_cursor(self): self._cp_file_handler = self._cp_file_path.open("w") self._buffer_cursor_lines_number = 0 self.__save_mvcc_ts() - log.warning( + LOGGER.warning( "iterator cp file is not existed any more, recreate for iteration, " "do not remove this file manually!" ) if self._buffer_cursor_lines_number >= 100: self._cp_file_handler.seek(0) self._cp_file_handler.truncate() - log.info( + LOGGER.info( "cursor lines in cp file has exceeded 100 lines, truncate the file and rewrite" ) self._buffer_cursor_lines_number = 0 @@ -229,7 +250,7 @@ def __setup_ts_by_request(self): if res.extra is not None: self._session_ts = res.extra.get(ITERATOR_SESSION_TS_FIELD, 0) if self._session_ts <= 0: - log.warning("failed to get mvccTs from milvus server, use client-side ts instead") + LOGGER.warning("failed to get mvccTs from milvus server, use client-side ts instead") self._session_ts = fall_back_to_latest_session_ts() self._kwargs[GUARANTEE_TIMESTAMP] = self._session_ts @@ -291,8 +312,7 @@ def next(self): else: iterator_cache.release_cache(self._cache_id_in_use) current_expr = self.__setup_next_expr() - if self._print_iterator_cursor: - log.info(f"query_iterator_next_expr:{current_expr}") + LOGGER.debug(f"query_iterator_next_expr:{current_expr}") res = self._conn.query( collection_name=self._collection_name, expr=current_expr, @@ -358,7 +378,7 @@ def close(self) -> None: def inner_close(): self._cp_file_handler.close() self._cp_file_path.unlink() - log.info(f"removed cp file:{self._cp_file_path_str} for query iterator") + LOGGER.info(f"removed cp file:{self._cp_file_path_str} for query iterator") io_operation( inner_close, f"failed to clear cp file:{self._cp_file_path_str} for query iterator" @@ -482,14 +502,13 @@ def __init__( self.__check_offset() self.__check_rm_range_search_parameters() self.__setup__pk_prop() - check_set_flag(self, "_print_iterator_cursor", self._kwargs, PRINT_ITERATOR_CURSOR) self.__init_search_iterator() def __init_search_iterator(self): init_page = self.__execute_next_search(self._param, self._expr, False) self._session_ts = init_page.get_session_ts() if self._session_ts <= 0: - log.warning("failed to set up mvccTs from milvus server, use client-side ts instead") + LOGGER.warning("failed to set up mvccTs from milvus server, use client-side ts instead") self._session_ts = fall_back_to_latest_session_ts() self._kwargs[GUARANTEE_TIMESTAMP] = self._session_ts if len(init_page) == 0: @@ -693,8 +712,7 @@ def __try_search_fill(self) -> SearchPage: def __execute_next_search( self, next_params: dict, next_expr: str, to_extend_batch: bool ) -> SearchPage: - if self._print_iterator_cursor: - log.info(f"search_iterator_next_expr:{next_expr}, next_params:{next_params}") + LOGGER.debug(f"search_iterator_next_expr:{next_expr}, next_params:{next_params}") res = self._conn.search( self._iterator_params["collection_name"], self._iterator_params["data"],