|
| 1 | +from typing import List, Dict |
| 2 | + |
| 3 | +import pandas as pd |
| 4 | + |
| 5 | +from reinvent_scoring.scoring.diversity_filters.curriculum_learning.column_names_enum import ColumnNamesEnum |
| 6 | +from reinvent_scoring.scoring.diversity_filters.curriculum_learning.memory_record_dto import MemoryRecordDTO |
| 7 | +from reinvent_scoring.scoring.score_summary import ComponentSummary |
| 8 | +from reinvent_scoring.scoring.enums.scoring_function_component_enum import ScoringFunctionComponentNameEnum |
| 9 | + |
| 10 | + |
| 11 | +class DiversityFilterMemory: |
| 12 | + |
| 13 | + def __init__(self): |
| 14 | + self._sf_component_name = ScoringFunctionComponentNameEnum() |
| 15 | + self._column_name = ColumnNamesEnum() |
| 16 | + df_dict = {self._column_name.STEP: [], self._column_name.SCAFFOLD: [], self._column_name.SMILES: [], |
| 17 | + self._column_name.METADATA: []} |
| 18 | + self._memory_dataframe = pd.DataFrame(df_dict) |
| 19 | + |
| 20 | + def update(self, dto: MemoryRecordDTO): |
| 21 | + component_scores = {c.parameters.name: float(c.total_score[dto.id]) for c in dto.components} |
| 22 | + component_scores = self._include_raw_score(dto.id, component_scores, dto.components) |
| 23 | + component_scores[self._sf_component_name.TOTAL_SCORE] = float(dto.score) |
| 24 | + if not self.smiles_exists(dto.smile): self._add_to_memory_dataframe(dto, component_scores) |
| 25 | + |
| 26 | + def _add_to_memory_dataframe(self, dto: MemoryRecordDTO, component_scores: Dict): |
| 27 | + data = [] |
| 28 | + headers = [] |
| 29 | + for name, score in component_scores.items(): |
| 30 | + headers.append(name) |
| 31 | + data.append(score) |
| 32 | + headers.append(self._column_name.STEP) |
| 33 | + data.append(dto.step) |
| 34 | + headers.append(self._column_name.SCAFFOLD) |
| 35 | + data.append(dto.scaffold) |
| 36 | + headers.append(self._column_name.SMILES) |
| 37 | + data.append(dto.smile) |
| 38 | + headers.append(self._column_name.METADATA) |
| 39 | + data.append(dto.loggable_data) |
| 40 | + new_data = pd.DataFrame([data], columns=headers) |
| 41 | + self._memory_dataframe = pd.concat([self._memory_dataframe, new_data], ignore_index=True, sort=False) |
| 42 | + |
| 43 | + def get_memory(self) -> pd.DataFrame: |
| 44 | + return self._memory_dataframe |
| 45 | + |
| 46 | + def set_memory(self, memory: pd.DataFrame): |
| 47 | + self._memory_dataframe = memory |
| 48 | + |
| 49 | + def smiles_exists(self, smiles: str): |
| 50 | + if len(self._memory_dataframe) == 0: |
| 51 | + return False |
| 52 | + return smiles in self._memory_dataframe[self._column_name.SMILES].values |
| 53 | + |
| 54 | + def scaffold_instances_count(self, scaffold: str): |
| 55 | + return (self._memory_dataframe[self._column_name.SCAFFOLD].values == scaffold).sum() |
| 56 | + |
| 57 | + def number_of_scaffolds(self): |
| 58 | + return len(set(self._memory_dataframe[self._column_name.SCAFFOLD].values)) |
| 59 | + |
| 60 | + def number_of_smiles(self): |
| 61 | + return len(set(self._memory_dataframe[self._column_name.SMILES].values)) |
| 62 | + |
| 63 | + def _include_raw_score(self, indx: int, component_scores: dict, components: List[ComponentSummary]): |
| 64 | + raw_scores = {f'raw_{c.parameters.name}': float(c.raw_score[indx]) for c in components if |
| 65 | + c.raw_score is not None} |
| 66 | + all_scores = {**component_scores, **raw_scores} |
| 67 | + return all_scores |
0 commit comments