Some general points about sample handling #459

Open
jwa7 opened this issue Jan 28, 2025 · 2 comments
Labels: Discussion, Infrastructure: Miscellaneous, SOAP BPNN

Comments

jwa7 (Member) commented Jan 28, 2025

1. System indexing

Currently, a prediction by SoapBpnn produces a TensorMap whose "system" sample indices run from 0 to N_batch - 1. These indices do not match the indices of the systems as they are defined in the dataset.

This is not especially problematic in itself, but I noticed that the metadata checks in TensorMapLoss only compare the number of samples, whereas full equivalence is checked for properties and components:

            if block_1.properties != block_2.properties:
                raise ValueError(
                    "TensorMapLoss requires the two TensorMaps to have the same "
                    "properties."
                )
            if block_1.components != block_2.components:
                raise ValueError(
                    "TensorMapLoss requires the two TensorMaps to have the same "
                    "components."
                )
            if len(block_1.samples) != len(block_2.samples):
                raise ValueError(
                    "TensorMapLoss requires the two TensorMaps "
                    "to have the same number of samples."
                )

To achieve full equivalence, one could simply re-index the "system" sample dimension of the model predictions to match the dataset indices of the systems in the batch. This would not be particularly heavy computationally, but it might break things if the former convention is assumed elsewhere in the code.
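A minimal sketch of this re-indexing, followed by a samples check that requires full equivalence rather than equal length. It is plain Python and does not use metatensor: samples are modelled as lists of (system, atom) tuples instead of Labels, and the helper names `reindex_systems` and `check_same_samples`, as well as the `batch_system_ids` argument (the dataset indices of the systems in the batch, in batch order), are hypothetical.

```python
from typing import List, Sequence, Tuple

Sample = Tuple[int, ...]


def reindex_systems(
    samples: Sequence[Sample],
    batch_system_ids: Sequence[int],
    system_column: int = 0,
) -> List[Sample]:
    """Replace batch-local system indices (0..N_batch - 1) by dataset indices."""
    reindexed = []
    for sample in samples:
        values = list(sample)
        # batch-local index i refers to the i-th system of the minibatch
        values[system_column] = batch_system_ids[values[system_column]]
        reindexed.append(tuple(values))
    return reindexed


def check_same_samples(samples_1: Sequence[Sample], samples_2: Sequence[Sample]) -> None:
    """Stricter than comparing len(samples): entries and order must match."""
    if list(samples_1) != list(samples_2):
        raise ValueError(
            "TensorMapLoss requires the two TensorMaps to have the same samples."
        )


# prediction samples (system, atom) with batch-local system indices
prediction = [(0, 0), (0, 1), (1, 0)]
# target samples for the same batch, using the dataset indices 12 and 40
target = [(12, 0), (12, 1), (40, 0)]

# passes silently; comparing `prediction` to `target` directly would raise
check_same_samples(reindex_systems(prediction, batch_system_ids=[12, 40]), target)
```

Since only the "system" column is rewritten, the cost is linear in the number of samples.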

2. Selected samples/atoms

I originally noticed the above problem when attempting to make predictions on a subset of samples using a SOAP-BPNN model for spherical targets. Currently, the forward methods of the SoapBpnn and TensorBasis classes take a selected_atoms argument, which is passed to the SOAP calculator as selected_samples. However, this parameter is not currently exposed to the user.
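To illustrate what such a selection could contain, the sketch below builds the (system, atom) pairs for all atoms of a given type in a batch. It is plain Python; the helper `select_atoms_by_type` and the input layout (one list of atomic numbers per system) are hypothetical, and converting the pairs into whatever object the SOAP calculator expects for selected_samples is not shown. The system indices here are batch-local, following the convention described in point 1.

```python
from typing import List, Sequence, Tuple


def select_atoms_by_type(
    atomic_types: Sequence[Sequence[int]],  # atomic_types[s][a]: Z of atom a in system s
    wanted_types: Sequence[int],
) -> List[Tuple[int, int]]:
    """Return (system, atom) pairs for all atoms whose type is in wanted_types."""
    wanted = set(wanted_types)
    return [
        (system, atom)
        for system, types in enumerate(atomic_types)
        for atom, z in enumerate(types)
        if z in wanted
    ]


# two small systems: water-like and ammonia-like (H = 1, N = 7, O = 8)
systems = [[8, 1, 1], [7, 1, 1, 1]]
print(select_atoms_by_type(systems, wanted_types=[7]))  # [(1, 0)]
```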

This also links to issue #458. If the spherical target is an electron density for which independent models are initialised per block, one could in principle compute SOAP descriptors with different hypers per atom type (using the selected_samples functionality of featomic to compute, for instance, descriptors only for nitrogen). These descriptors would then be passed through different models with separate weights, before being joined at the output level to form a prediction on the full basis.
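A rough sketch of the per-type routing described here, in plain Python: each atom's descriptor goes through the model for its atom type, and the outputs are gathered per center_type, standing in for the blocks of the output TensorMap. The linear models, descriptor values and helper names are made up for illustration. Descriptor and output sizes are allowed to differ between types, as they would with per-type hypers and per-type basis sets.

```python
from typing import Dict, List, Sequence, Tuple

Matrix = Sequence[Sequence[float]]
Output = Tuple[Tuple[int, int], List[float]]  # ((system, atom), values)


def linear(weights: Matrix, x: Sequence[float]) -> List[float]:
    """Dense linear model y = W x, with W given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def predict_per_type(
    atoms: Sequence[Tuple[int, int, int, Sequence[float]]],
    models: Dict[int, Matrix],
) -> Dict[int, List[Output]]:
    """Route each (system, atom, type, descriptor) through its type's model."""
    blocks: Dict[int, List[Output]] = {}
    for system, atom, center_type, descriptor in atoms:
        values = linear(models[center_type], descriptor)
        # gather outputs per center_type, like one output block per atom type
        blocks.setdefault(center_type, []).append(((system, atom), values))
    return blocks


# separate weights per type: nitrogen (7) maps a 2-dim descriptor to a 2-dim
# output, hydrogen (1) maps a 3-dim descriptor to a 1-dim output
models = {7: [[1.0, 1.0], [0.0, 2.0]], 1: [[1.0, 0.0, 0.0]]}
atoms = [(0, 0, 7, [1.0, 2.0]), (0, 1, 1, [0.5, 0.0, 0.0])]
print(predict_per_type(atoms, models))
# {7: [((0, 0), [3.0, 4.0])], 1: [((0, 1), [0.5])]}
```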

3. Dataloader joining

The dataloaders in the SOAP-BPNN trainer module use the metatensor-learn DataLoader class. This class accepts a join_kwargs argument whose entries are passed to the metatensor.join operation when minibatches are collated. Currently, these join_kwargs are not exposed to the user, nor are any passed by default.

For instance, model predictions carry an extra "tensor" sample dimension, because remove_tensor_name=True is not passed to the dataloader in join_kwargs. This also relates to point 1 above, as it means the sample metadata of targets and predictions cannot be compared directly.
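A simplified, plain-Python model of this behaviour: joining the samples of several tensors prefixes each sample with the index of the tensor it came from, and remove_tensor_name=True drops that extra dimension when the remaining samples are still unique. This mimics, but is not, metatensor.join; the function `join_samples` is hypothetical.

```python
from typing import List, Sequence, Tuple

Sample = Tuple[int, ...]


def join_samples(
    names: Sequence[str],
    per_tensor_samples: Sequence[Sequence[Sample]],
    remove_tensor_name: bool = False,
) -> Tuple[List[str], List[Sample]]:
    """Concatenate samples, prefixing each with the index of its source tensor."""
    joined = [
        (tensor,) + tuple(sample)
        for tensor, samples in enumerate(per_tensor_samples)
        for sample in samples
    ]
    without_tensor = [sample[1:] for sample in joined]
    # the extra dimension can only be dropped if the samples stay unique
    if remove_tensor_name and len(set(without_tensor)) == len(without_tensor):
        return list(names), without_tensor
    return ["tensor"] + list(names), joined


# targets for two systems, with dataset indices 12 and 40
targets = [[(12, 0), (12, 1)], [(40, 0)]]
print(join_samples(["system", "atom"], targets))
# (['tensor', 'system', 'atom'], [(0, 12, 0), (0, 12, 1), (1, 40, 0)])
print(join_samples(["system", "atom"], targets, remove_tensor_name=True))
# (['system', 'atom'], [(12, 0), (12, 1), (40, 0)])
```

Because the system indices already distinguish the samples, dropping "tensor" loses no information in this case.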

Further to this, and again using the example of an electron density expanded on a basis, the different_keys="union" parameter of metatensor.join is a useful argument to have control over. The basis set definitions of the systems in a minibatch are often inconsistent when the systems contain different atom types. In this case, compiling a minibatch from such systems requires taking the union of their keys (assuming, for instance, key names ["o3_lambda", "o3_sigma", "center_type"]).
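The sketch below models how the key sets of the systems in a minibatch could be resolved under the three different_keys options of metatensor.join ("error", "intersection", "union"). It is plain Python, not the library implementation; `resolve_keys` is a hypothetical helper, keys are (o3_lambda, o3_sigma, center_type) tuples, and the per-system key sets are invented for the example.

```python
from typing import List, Sequence, Set, Tuple

Key = Tuple[int, int, int]  # (o3_lambda, o3_sigma, center_type)


def resolve_keys(per_system_keys: Sequence[Set[Key]], different_keys: str = "error") -> List[Key]:
    """Return the keys of the joined minibatch, sorted for readability."""
    sets = [set(keys) for keys in per_system_keys]
    if different_keys == "error":
        if any(keys != sets[0] for keys in sets[1:]):
            raise ValueError("systems in the minibatch have different keys")
        return sorted(sets[0])
    if different_keys == "intersection":
        # only keys present for every system are kept
        return sorted(set.intersection(*sets))
    if different_keys == "union":
        # a system lacking a key would contribute an empty block for it
        return sorted(set.union(*sets))
    raise ValueError(f"unknown different_keys option: {different_keys!r}")


# invented per-type basis: H (1) and O (8) in water, H (1) and N (7) in ammonia
water = {(0, 1, 1), (0, 1, 8), (1, 1, 8)}
ammonia = {(0, 1, 1), (0, 1, 7), (1, 1, 7)}
print(resolve_keys([water, ammonia], "union"))
# [(0, 1, 1), (0, 1, 7), (0, 1, 8), (1, 1, 7), (1, 1, 8)]
```

With "error", this minibatch could not be collated at all, and "intersection" would keep only the hydrogen key, discarding most of the target.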

jwa7 added the Discussion, Infrastructure: Miscellaneous and SOAP BPNN labels on Jan 28, 2025
Luthaf (Member) commented Jan 28, 2025

I agree that some of these could be exposed to users, but not in the same way they are exposed in the Python API. We should not expect users of metatrain to even know about metatensor-learn. In general, I think we should expose things in the metatrain input in very high-level terms, or make the decision on behalf of the user.

For example, remove_tensor_name=True could always be given, rather than being an option. Basically, I'm trying to limit the number of options we give to users, so as not to overload them with choices and things to understand before they can train a model.

jwa7 (Member, Author) commented Jan 28, 2025

Yes, I agree with the spirit of limiting options where possible. Perhaps the solution is to define sensible internal defaults based on the use case.

For example, the default minibatch collating behaviour for generic targets could be what it currently is, i.e. different_keys="error", but for spherical targets it could be different_keys="union", under the assumption that this type of target could be used for basis set definitions that depend on the atom type.
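A hypothetical sketch of how such internal defaults could be chosen, combining this proposal with always passing remove_tensor_name=True. The function name and the target-type strings ("scalar", "spherical") are assumptions; the returned dictionary is meant as the join_kwargs for the dataloader, and users would never set it directly.

```python
from typing import Dict, Union


def default_join_kwargs(target_type: str) -> Dict[str, Union[bool, str]]:
    """Hypothetical internal defaults for the join_kwargs of the dataloader."""
    kwargs: Dict[str, Union[bool, str]] = {
        # always drop the extra "tensor" sample dimension, never a user option
        "remove_tensor_name": True,
        # generic targets keep the current behaviour
        "different_keys": "error",
    }
    if target_type == "spherical":
        # basis set definitions may depend on the atom types in each system
        kwargs["different_keys"] = "union"
    return kwargs


print(default_join_kwargs("scalar"))
# {'remove_tensor_name': True, 'different_keys': 'error'}
print(default_join_kwargs("spherical"))
# {'remove_tensor_name': True, 'different_keys': 'union'}
```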
