| import gradio as gr |
| from transformers import AutoProcessor, AutoModelForCausalLM, pipeline |
| import torch |
|
|
# Run on the first CUDA GPU when one is visible; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
| |
# Pre-trained GIT-base captioning checkpoint. Its processor is also used to
# preprocess images for the fine-tuned model.
checkpoint1 = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint1)

# from_pretrained() loads weights on CPU. Move them to the selected device so
# they match the inputs; on a GPU machine, CPU weights with CUDA inputs make
# generate() fail with a device-mismatch error.
model1 = AutoModelForCausalLM.from_pretrained(checkpoint1).to(device)
|
|
| |
# GIT-base checkpoint fine-tuned for captioning. It is fed images preprocessed
# by the microsoft/git-base `processor`.
checkpoint2 = "wangjin2000/git-base-finetune"

# Place the weights on the same device as the inputs. Without .to(device) the
# model stays on CPU even when CUDA is available.
model2 = AutoModelForCausalLM.from_pretrained(checkpoint2).to(device)
|
|
| |
| |
# English -> Chinese translation pipeline, used to translate each generated caption.
en_zh_translator = pipeline(
    task="translation",
    model="liam168/trans-opus-mt-en-zh",
)
|
|
def _generate_caption(model, pixel_values):
    """Return the first decoded caption that *model* generates for *pixel_values*."""
    # Send inputs to whatever device the model actually lives on. This works
    # whether or not the weights were moved to the global `device`.
    generated_ids = model.generate(
        pixel_values=pixel_values.to(model.device), max_length=50
    )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def _append_translation(caption):
    """Return *caption* followed by its Chinese translation on the next line."""
    # The translation pipeline returns one dict per input text, e.g.
    # [{"translation_text": "..."}]. Unwrap it so the textbox shows plain text
    # instead of a stringified Python list/dict.
    translation = en_zh_translator(caption)[0]["translation_text"]
    return f"{caption}\n{translation}"


def img2cap_com(image):
    """Caption *image* with both the pre-trained and the fine-tuned GIT model.

    Args:
        image: The uploaded image. It is a PIL image, per
            ``gr.Image(type="pil")``, or ``None`` when the user submits
            without uploading one.

    Returns:
        A ``(pretrained, finetuned)`` tuple of strings, one per output
        textbox. Each string holds the English caption followed by its
        Chinese translation on the next line. Both strings are empty when
        *image* is None.
    """
    if image is None:
        return "", ""

    # Both models use the same processor, so preprocess the image only once.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    caption1 = _generate_caption(model1, pixel_values)
    caption2 = _generate_caption(model2, pixel_values)
    return _append_translation(caption1), _append_translation(caption2)
|
|
# Gradio UI: one image in, two captions out (pre-trained vs fine-tuned model).
inputs = [gr.Image(type="pil", label="Original Image")]

outputs = [
    gr.Textbox(label="Caption from pre-trained model"),
    gr.Textbox(label="Caption from fine-tuned model"),
]

title = "Image Captioning using Pre-trained and Fine-tuned Model"
description = "GIT-base is used to generate Image Caption for the uploaded image."

# Example images shown under the input.
# NOTE(review): these files are assumed to ship alongside the app; confirm.
examples = [
    ["Image1.png"],
    ["Image2.png"],
    ["Image3.png"],
    ["Image4.png"],
    ["Image5.png"],
    ["Image6.png"],
]

# NOTE(review): "huggingface" is a legacy theme name. Newer Gradio versions may
# warn and fall back to the default theme; confirm against the pinned version.
demo = gr.Interface(
    fn=img2cap_com,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
    theme="huggingface",
)
demo.launch()
|
|