Reproducing CT-RATE retrieval numbers

Discussion #6, opened by gings

Hi, thanks a lot for releasing COLIPRI.

I am currently trying to reproduce the CT-RATE retrieval numbers:

  1. The CT-RATE validation split has 1,564 unique reports, but the paper evaluates on 1,493 report-scan pairs. Could you release the list of the 1,493 .nii.gz volume names you used?

  2. If possible, could you also release the CT-RATE retrieval evaluation script? I implemented one, but it is currently about 10 percentage points below the paper on R@10. I suspect differences in preprocessing, report formatting, or other details that are difficult to infer from the current repo and the paper.

Thanks for considering this, and best regards,

Microsoft org
edited May 18

Hey @gings.

  1. If you exclude the studies that are not chest CTs (no_chest_valid.txt), you go from 1564 to 1551.
  2. If you then drop duplicates (by Findings_EN), you go from 1551 to 1493.
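The two filtering steps above can be sketched as follows. This is a minimal, hypothetical illustration, not the authors' code: it assumes the reports are already loaded as dicts with a `study` key and the `Findings_EN` text, that no_chest_valid.txt has been read into a set of study identifiers, and that deduplication keeps the first occurrence of each exact `Findings_EN` string (the thread does not say which duplicate is kept).

```python
from collections.abc import Iterable


def build_eval_set(reports: Iterable[dict], non_chest_studies: set[str]) -> list[dict]:
    """Hypothetical sketch of the two filtering steps described in the thread.

    Assumes one dict per study with a "study" key and a "Findings_EN" key, and
    that `non_chest_studies` holds the identifiers from no_chest_valid.txt.
    """
    # Step 1: drop studies that are not chest CTs (1564 -> 1551 in the thread).
    chest_only = [r for r in reports if r["study"] not in non_chest_studies]

    # Step 2: drop reports whose Findings_EN text is an exact duplicate,
    # keeping the first occurrence (assumption; 1551 -> 1493 in the thread).
    seen: set[str] = set()
    deduplicated = []
    for report in chest_only:
        text = report["Findings_EN"]
        if text not in seen:
            seen.add(text)
            deduplicated.append(report)
    return deduplicated
```

Because the second step compares exact strings, any normalization of the report text (whitespace, casing) before deduplication would change the resulting count.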

I'll see if we can help with the code for retrieval, but the implementation was based on the retrieval code from CXR-CLIP.

Microsoft org
edited May 18

Found it. Hope this helps:

    def evaluate_retrieval(self, overwrite: bool = False) -> dict[str, float]:
        # self.calculate_or_load_vision_embeddings, embed_text_prompts and save_json
        # come from the authors' codebase and are not shown here.
        vision_embeds = self.calculate_or_load_vision_embeddings(overwrite)
        vision_embeddings = vision_embeds["image_embeddings"].squeeze(dim=-1)
        normed_vision_embeddings = vision_embeddings / torch.norm(vision_embeddings, dim=-1, keepdim=True)  # [N, 768]

        findings = vision_embeds["findings"]

        language_embeddings = embed_text_prompts(
            findings,
            text_encoder=self.text_encoder,
            text_pooler=self.text_pooler,
            text_tokenizer=self.text_tokenizer,
        )
        language_embeddings = torch.cat(language_embeddings, dim=0).squeeze(dim=-1)
        normed_language_embeddings = (
            language_embeddings / torch.norm(language_embeddings, dim=-1, keepdim=True)
        ).cpu()

        # Dividing by the temperature scales every embedding by the same positive
        # constant, so it does not change the cosine similarities or rankings below.
        temp_scaled_vision_embeddings = normed_vision_embeddings / self.temperature
        temp_scaled_language_embeddings = normed_language_embeddings / self.temperature
        test_retrieval_metrics = retrieval_image_text(
            image_embeddings=temp_scaled_vision_embeddings.cpu(),
            text_embeddings=temp_scaled_language_embeddings.cpu(),
            text_list=findings,
        )
        test_retrieval_report_image = retrieval_image_text(
            image_embeddings=temp_scaled_vision_embeddings.cpu(),
            text_embeddings=temp_scaled_language_embeddings.cpu(),
            text_list=findings,
            do_image_text_retrieval=False,
        )
        test_retrieval_metrics.update(test_retrieval_report_image)
        save_json(test_retrieval_metrics, os.path.join(self.output_folder, "test_retrieval_CT-RATE.json"))
        return test_retrieval_metrics

import numpy as np
from sklearn import metrics

def retrieval_image_text(
    image_embeddings: np.ndarray,
    text_embeddings: np.ndarray,
    text_list: list = [],
    do_image_text_retrieval: bool = True,
):
    image_embeddings = np.array(image_embeddings)
    text_embeddings = np.array(text_embeddings)
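    # Map each sample index to the index of its deduplicated report text, so
    # that identical reports count as a single retrieval target.
    identical_text_set = []
    text_to_label = {}  # O(1) lookup instead of list.index
    idx2label = {}
    identical_indexes = []
    for i, text in enumerate(text_list):
        if text not in text_to_label:
            text_to_label[text] = len(identical_text_set)
            identical_text_set.append(text)
            identical_indexes.append(i)
        idx2label[i] = text_to_label[text]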

    identical_text_embedding = text_embeddings[identical_indexes]

    num_samples = image_embeddings.shape[0]
    n_text = len(identical_text_set)

    result = {}

    try:
        if do_image_text_retrieval:
            similarities = metrics.pairwise.cosine_similarity(image_embeddings, identical_text_embedding)  # n x m

            recall_dict: dict[int, float] = {1: 0, 5: 0, 10: 0}
            map_dict: dict[int, float] = {1: 0, 5: 0, 10: 0}
            ranks = []
            for idx in range(num_samples):
                label = idx2label[idx]
                similarity = similarities[idx]
                similarity_args = similarity.argsort()

                # argsort is ascending, so the best match is last; convert the
                # position of the paired text into a 1-indexed rank.
                rank = n_text - np.argwhere(similarity_args == label).ravel()[0]
                ranks.append(rank)

                reciprocal_rank = 1 / float(rank)
                for k in recall_dict:
                    if rank <= k:
                        recall_dict[k] += 1
                        map_dict[k] += reciprocal_rank
        else:
            # Text-to-image retrieval: for each unique text, find matching images
            similarities = metrics.pairwise.cosine_similarity(
                identical_text_embedding, image_embeddings
            )  # n_unique_texts x n_images

            recall_dict: dict[int, float] = {1: 0, 5: 0, 10: 0}
            map_dict: dict[int, float] = {1: 0, 5: 0, 10: 0}
            ranks = []

            # Create reverse mapping: unique_text_idx -> list of image indices with that text
            text_to_image_indices = {}
            for img_idx, unique_text_idx in idx2label.items():
                if unique_text_idx not in text_to_image_indices:
                    text_to_image_indices[unique_text_idx] = []
                text_to_image_indices[unique_text_idx].append(img_idx)

            for text_idx in range(n_text):
                similarity = similarities[text_idx]
                similarity_args = similarity.argsort()[::-1]  # Sort in descending order for text-to-image

                # Find the best rank among all images with this text
                matching_image_indices = text_to_image_indices[text_idx]
                best_rank = float("inf")

                for img_idx in matching_image_indices:
                    rank = np.argwhere(similarity_args == img_idx).ravel()[0] + 1  # 1-indexed
                    best_rank = min(best_rank, rank)

                ranks.append(best_rank)
                reciprocal_rank = 1 / float(best_rank)

                for k in recall_dict:
                    if best_rank <= k:
                        recall_dict[k] += 1
                        map_dict[k] += reciprocal_rank
            num_samples = n_text

        ranks = np.array(ranks)
        result.update({f"RecallAt{k}": v / num_samples for k, v in recall_dict.items()})
        result.update({f"MeanAveragePrecisionAt{k}": v / num_samples for k, v in map_dict.items()})
        result.update({"MeanReciprocalRank": float(np.mean(1 / ranks))})
        result.update({"MeanRank": float(np.mean(ranks))})
        result.update({"MedianRank": float(np.median(ranks))})
        result.update({"NumSamples": num_samples})
    except ValueError:
        recall_dict = {1: np.nan, 5: np.nan, 10: np.nan}
        map_dict = {1: np.nan, 5: np.nan, 10: np.nan}
        result.update({f"RecallAt{k}": v / num_samples for k, v in recall_dict.items()})
        result.update({f"MeanAveragePrecisionAt{k}": v / num_samples for k, v in map_dict.items()})
        result.update({"MeanReciprocalRank": np.nan})
        result.update({"MeanRank": float(np.nan)})
        result.update({"MedianRank": float(np.nan)})
        result.update({"NumSamples": np.nan})

    if do_image_text_retrieval:
        result = {f"ImageText_{k}": v for k, v in result.items()}
    else:
        result = {f"TextImage_{k}": v for k, v in result.items()}

    return result
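For reference, in the one-to-one setting that remains after deduplication (every report is unique, so image i matches only report i), the ranks and metrics computed by the loops above can be sketched in vectorized NumPy. The function name `paired_retrieval_metrics` is hypothetical, and this sketch counts only strictly higher-scoring candidates, so exact ties are resolved in favour of the true pair, whereas the `argsort`-based loop breaks ties by index order.

```python
import numpy as np


def paired_retrieval_metrics(
    query_embeddings: np.ndarray,
    gallery_embeddings: np.ndarray,
    ks: tuple[int, ...] = (1, 5, 10),
) -> dict[str, float]:
    """Hypothetical sketch: retrieval metrics when query i is paired only with gallery item i."""
    if query_embeddings.shape[0] != gallery_embeddings.shape[0]:
        raise ValueError("expected exactly one gallery item per query")

    # L2-normalize so that the dot product equals cosine similarity.
    q = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
    g = gallery_embeddings / np.linalg.norm(gallery_embeddings, axis=1, keepdims=True)
    similarities = q @ g.T  # [num_queries, num_gallery]

    # Similarity of each query to its own paired item (the diagonal).
    positive = np.diag(similarities)[:, None]
    # 1-indexed rank = 1 + number of gallery items scored strictly higher.
    ranks = 1 + (similarities > positive).sum(axis=1)

    metrics = {f"RecallAt{k}": float(np.mean(ranks <= k)) for k in ks}
    # With a single relevant item, AP@K is 1/rank if rank <= K, else 0.
    for k in ks:
        metrics[f"MeanAveragePrecisionAt{k}"] = float(np.mean(np.where(ranks <= k, 1.0 / ranks, 0.0)))
    metrics["MeanReciprocalRank"] = float(np.mean(1.0 / ranks))
    metrics["MeanRank"] = float(np.mean(ranks))
    metrics["MedianRank"] = float(np.median(ranks))
    return metrics
```

Pass image embeddings as queries and report embeddings as the gallery for image-to-report retrieval, and swap the arguments for report-to-image. The per-K accumulation mirrors what `map_dict` collects in the code above.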

Hi, thanks for your quick reply,

I am still getting worse numbers than reported in the paper. I have created a standalone script to reproduce the problem:

https://gist.github.com/simon-ging/18720ce09be4943d738264b40ad8d839

Script output:

$ python colipri_retrieval_ctrate.py
nohead: dropped 13 studies, 3039 -> 2975 volumes
dedup: 1493 studies after smallest-spacing + identical-text dedup
Eval set: 1493 studies, 1493 unique reports
Loading weights: 100%
image: 100%
text: 100%
#################### TextImage ####################
R@1:  11.45
R@5:  27.86
R@10:  38.45
MedR:  23.00
MeanR:  92.52
MRR: 0.1994
N: 1493
#################### ImageText ####################
R@1:   6.50
R@5:  21.10
R@10:  30.54
MedR:  35.00
MeanR: 117.02
MRR: 0.1446
N: 1493

Expected: 46.01 / 40.10 R@10 for report-to-image / image-to-report.

If possible, please check it and let me know how to reproduce the results from the paper.

Best,

Microsoft org

Hi @gings , thanks again for the detailed write-up and the standalone script.

The image orientation used in this repo might have been incorrect. I've released a new version (v0.1.2).
Also, you need to select the volumes with the smallest size, not the smallest spacing (which would select the opposite). I don't expect that to make a big difference, though.

Could you please upgrade colipri, replace XYSpacing with Columns, run your script, and report your results (if you have time, you could also try each change separately to avoid confounding factors)?
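The selection change requested above (one volume per study, chosen by the smallest `Columns` value instead of the smallest `XYSpacing`) could look like the sketch below. It is hypothetical, not code from the thread: it assumes per-volume metadata as dicts with a `study` key and numeric `Columns`/`XYSpacing` values, which may need parsing from the actual CT-RATE metadata files.

```python
from collections import defaultdict


def select_one_volume_per_study(volumes: list[dict], key: str = "Columns") -> dict[str, dict]:
    """Hypothetical sketch: keep, for each study, the volume with the smallest `key`.

    Assumes each volume dict has a "study" key and a numeric metadata value under
    `key` (e.g. "Columns" or "XYSpacing", the column names mentioned in the thread).
    """
    by_study: dict[str, list[dict]] = defaultdict(list)
    for volume in volumes:
        by_study[volume["study"]].append(volume)
    # min() keeps the first volume encountered when several share the minimum.
    return {study: min(candidates, key=lambda v: v[key]) for study, candidates in by_study.items()}
```

Calling it with `key="XYSpacing"` reproduces the earlier smallest-spacing choice, which for studies with several reconstructions tends to pick the finest-resolution volume rather than the smallest one.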

Hi, thanks for your help,

I ran it as described and, interestingly, the results are now "too good":

Expected: 46.01 / 40.10 R@10 for report-to-image / image-to-report.

I got: 60.82 / 61.89 after both changes.

Details:

With upgraded colipri only

nohead: dropped 13 studies, 3039 -> 2975 volumes
dedup: 1493 studies after smallest-spacing + identical-text dedup
Eval set: 1493 studies, 1493 unique reports
#################### TextImage ####################
R@1:  24.38
R@5:  47.56
R@10:  57.87
MedR:   6.00
MeanR:  48.25
MRR: 0.3570
N: 1493
#################### ImageText ####################
R@1:  23.17
R@5:  47.15
R@10:  57.60
MedR:   6.00
MeanR:  47.03
MRR: 0.3454
N: 1493

With upgraded colipri + selecting smallest volume

nohead: dropped 13 studies, 3039 -> 2975 volumes
dedup: 1493 studies left
Eval set: 1493 studies, 1493 unique reports
#################### TextImage ####################
R@1:  26.46
R@5:  50.77
R@10:  60.82
MedR:   5.00
MeanR:  42.81
MRR: 0.3810
N: 1493
#################### ImageText ####################
R@1:  25.72
R@5:  49.83
R@10:  61.89
MedR:   6.00
MeanR:  40.16
MRR: 0.3735
N: 1493

I can only speculate: maybe the checkpoint has already seen the validation set? Or the model really is that good, but the paper reported the numbers incorrectly, or a different checkpoint was evaluated than the one uploaded? Or something could still be off with my script.

Once I have more time, I will run more tasks (zero-shot disease classification or retrieval on other datasets).

Best,

Simon

Microsoft org

Yeah, I got those numbers too, but I wanted you to try in case I was doing something wrong. The encoder never saw the official validation set during training. The uploaded checkpoint corresponds to COLIPRI-CRM in the paper. I suspect we used the wrong orientation for the paper, and I need to submit the camera-ready version today 🤡.

Anyway, I'm glad I could help.
