Support model in_chans not equal to pre-trained weights in_chans #2289

Open
adamjstewart opened this issue Sep 9, 2024 · 1 comment
Labels: good first issue, models, trainers

Comments

@adamjstewart
Collaborator

Summary

If a user specifies both in_chans and weights, and weights.meta['in_chans'] differs from in_chans, the user-specified argument should take precedence and the pre-trained weights should be repeated along the channel dimension, similar to how timm handles pre-trained weights.

Rationale

When working on change detection, it is common to take two images and stack them along the channel dimension. However, this makes it impossible to use our pre-trained weights. Ideally, I would like to support something like:

from torchgeo.models import ResNet50_Weights, resnet50

model = resnet50(in_chans=4, weights=ResNet50_Weights.SENTINEL1_ALL_MOCO)

Here, the weights have 2 channels (HH and HV), while the dataset and model will have 4 channels (HH, HV, HH, HV).
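A minimal sketch of the repetition described above, in plain PyTorch. The tensor shape (64 filters, 7x7 kernel) is that of a standard ResNet stem, and the random tensor is only a stand-in for the real pre-trained weights; both are assumptions for illustration. The rescaling step mirrors what timm does for RGB weights and is not something the proposal above specifies.

```python
import torch

# Stand-in for the pre-trained first conv weight: (out_channels, in_channels, kH, kW).
# Channel 0 plays the role of HH and channel 1 the role of HV.
pretrained = torch.randn(64, 2, 7, 7)

# Tile along the input-channel dimension: (HH, HV) -> (HH, HV, HH, HV).
repeated = pretrained.repeat(1, 2, 1, 1)

# Scale by old_channels / new_channels so that summing over twice as many
# input channels keeps activations at roughly the same magnitude.
repeated = repeated * (2 / 4)
```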

Implementation

https://timm.fast.ai/models#Case-2:-When-the-number-of-input-channels-is-not-1 describes the implementation that timm uses. This can be imported as:

from timm.models.helpers import load_pretrained

We should make use of this in all of our model definitions instead of model.load_state_dict.
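To illustrate why model.load_state_dict alone is not enough, here is a rough sketch with a toy one-layer model. The model, the checkpoint, and the key name "0.weight" are all stand-ins invented for this example; they are not torchgeo or timm code.

```python
import torch
from torch import nn

# Stand-ins: a 4-channel model and a 2-channel "pre-trained" checkpoint.
model = nn.Sequential(nn.Conv2d(4, 8, kernel_size=3, bias=False))
state_dict = nn.Sequential(nn.Conv2d(2, 8, kernel_size=3, bias=False)).state_dict()

# Plain load_state_dict rejects the checkpoint because the weight shapes differ.
try:
    model.load_state_dict(state_dict)
    load_failed = False
except RuntimeError:
    load_failed = True

# Adapt the first conv weight to 4 input channels, then load as usual.
key = "0.weight"  # name of the first conv weight in this toy model
state_dict[key] = state_dict[key].repeat(1, 2, 1, 1) * (2 / 4)
result = model.load_state_dict(state_dict)
```

After the adaptation step the load succeeds with no missing or unexpected keys, which is roughly the behavior a shared loading helper would need to provide.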

Alternatives

There is some ongoing work to add a ChangeDetectionTask that may split each image into a separate sample key. However, there will always be models that require images stacked along the channel dimension, so I don't think we can avoid supporting this use case.

Additional information

No response

@adamjstewart adamjstewart added models Models and pretrained weights good first issue A good issue for a new contributor to work on trainers PyTorch Lightning trainers labels Sep 9, 2024
@keves1

keves1 commented Sep 17, 2024

I'm interested in contributing to this; could I have it assigned to me? I looked into load_pretrained and found that it only copies (repeats) the weights when the pre-trained weights have 3 input channels. Otherwise it falls back to random init for the first conv layer (see the link below to adapt_input_conv(), which is called by load_pretrained()).

https://github.com/huggingface/pytorch-image-models/blob/ee5b1e8217134e9f016a0086b793c34abb721216/timm/models/_manipulate.py#L256-L278

So this would need to be adapted, since the weights we would use have a variable number of input channels.
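One possible shape for such an adaptation is sketched below. The function name adapt_input_conv_any is hypothetical, and the logic is only modeled on the repeat-and-rescale strategy timm applies to 3-channel weights, generalized to any source channel count. Truncating when in_chans is smaller than the source channel count is an assumption of this sketch, not a vetted design.

```python
import math

import torch


def adapt_input_conv_any(in_chans: int, conv_weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: tile a (O, I, kH, kW) conv weight to `in_chans` inputs."""
    orig_dtype = conv_weight.dtype
    conv_weight = conv_weight.float()  # avoid half-precision arithmetic on CPU
    src_chans = conv_weight.shape[1]
    if in_chans == src_chans:
        return conv_weight.to(orig_dtype)
    # Repeat the source channels enough times, then cut to the requested count.
    repeat = math.ceil(in_chans / src_chans)
    conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans]
    # Rescale so the summed response stays at a comparable magnitude.
    conv_weight = conv_weight * (src_chans / in_chans)
    return conv_weight.to(orig_dtype)


# Usage: 2-channel weights (e.g. HH, HV) adapted to a 4-channel stacked input.
weights_2ch = torch.ones(8, 2, 3, 3)
weights_4ch = adapt_input_conv_any(4, weights_2ch)
```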
