| import os |
| import subprocess |
| import torch |
| import requests |
| from PIL import Image |
| from io import BytesIO |
| from test import just_get_sd_mask |
|
|
# Debug diagnostics printed once, at import time: the entries of /usr/local/
# and the value of torch.version.cuda.
# NOTE(review): os.listdir raises if /usr/local/ does not exist, which would
# make importing this module fail; consider removing once debugging is done.
print(os.listdir('/usr/local/'))
print(torch.version.cuda)
|
|
class EndpointHandler:
    """Handler that builds a mask for a fixed demo image and uploads it.

    Each call runs ``just_get_sd_mask`` on a hard-coded demo image and prompt,
    encodes the resulting image as JPEG, uploads it to the upload service and
    returns the ``'link'`` field of the service's JSON response.
    """

    # Demo inputs, exposed as class attributes so they can be overridden
    # (e.g. in a subclass) without editing ``__call__``.
    DEMO_IMAGE_PATH = "assets/demo1.jpg"  # resolved relative to the current working directory
    DEMO_PROMPT = "bear"
    # Third positional argument passed to just_get_sd_mask; its meaning is
    # defined by that function. NOTE(review): document once confirmed.
    DEMO_MASK_ARG = 10
    UPLOAD_URL = "https://file.io/"
    # Seconds to wait for the upload; without a timeout requests may block forever.
    UPLOAD_TIMEOUT = 60

    def __init__(self, path="."):
        """Create the handler.

        Args:
            path: Accepted for interface compatibility; currently unused.
        """
        pass  # no state to initialise

    def __call__(self, data):
        """Produce the demo mask, upload it as JPEG, and return its link.

        Args:
            data: Request payload. Currently ignored -- the image, prompt and
                mask argument come from the class attributes above.

        Returns:
            ``{"url": link}`` where ``link`` is the ``'link'`` value from the
            upload response's JSON body (``None`` if that key is absent).

        Raises:
            requests.HTTPError: if the upload responds with an HTTP error status.
            requests.Timeout: if the upload exceeds ``UPLOAD_TIMEOUT`` seconds.
        """
        # Use a context manager so the file handle Image.open keeps for lazy
        # loading is closed deterministically instead of leaking per request.
        with Image.open(self.DEMO_IMAGE_PATH) as source_image:
            mask_pil = just_get_sd_mask(source_image, self.DEMO_PROMPT, self.DEMO_MASK_ARG)

        # Pillow's JPEG writer rejects some modes (e.g. RGBA), so normalise to RGB.
        if mask_pil.mode != 'RGB':
            mask_pil = mask_pil.convert('RGB')

        buffer = BytesIO()
        mask_pil.save(buffer, format='JPEG')
        jpeg_bytes = buffer.getvalue()

        response = requests.post(
            self.UPLOAD_URL,
            files={"file": jpeg_bytes},
            timeout=self.UPLOAD_TIMEOUT,
        )
        # Fail loudly on an upload error rather than returning {"url": None}
        # or an opaque JSON decoding error.
        response.raise_for_status()
        url = response.json().get('link')

        return {"url": url}