| import os |
| import random |
| import tempfile |
| import time |
| import zipfile |
| from contextlib import nullcontext |
| from functools import lru_cache |
| from typing import Any |
|
|
| import cv2 |
| import gradio as gr |
| import numpy as np |
| import torch |
| import trimesh |
| from gradio_litmodel3d import LitModel3D |
| from gradio_pointcloudeditor import PointCloudEditor |
| from PIL import Image |
| from transparent_background import Remover |
|
|
# Install the bundled native extensions and the prebuilt pynim wheel at
# start-up. This runs ahead of the `spar3d` imports that follow, presumably
# because those imports need the freshly installed packages.
# NOTE: os.system exit codes are not inspected, so a failed install is not
# reported here.
for _install_cmd in (
    "USE_CUDA=1 pip install -vv --no-build-isolation ./texture_baker ./uv_unwrapper",
    "pip install ./deps/pynim-0.0.3-cp310-cp310-linux_x86_64.whl",
):
    os.system(_install_cmd)
|
|
| import spar3d.utils as spar3d_utils |
| from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE |
| from spar3d.system import SPAR3D |
|
|
# Point Gradio's temporary files at a dedicated "gradio" folder under TMPDIR
# (falling back to /tmp when TMPDIR is not set).
_tmp_root = os.environ.get("TMPDIR", "/tmp")
os.environ["GRADIO_TEMP_DIR"] = os.path.join(_tmp_root, "gradio")

# Single shared background remover, used by remove_background().
bg_remover = Remover()

# Settings of the conditioning view. Width/height are the size the input image
# is resized to; distance and FOV are handed to the spar3d camera helpers
# below (the helper name suggests the FOV is in radians -- TODO confirm).
COND_WIDTH, COND_HEIGHT = 512, 512
COND_DISTANCE = 2.2
COND_FOVY = 0.591627
# Colour used by create_batch() wherever the input mask is zero.
BACKGROUND_COLOR = [0.5, 0.5, 0.5]

# Camera pose and intrinsics of the conditioning view, computed once at import
# time and reused for every request.
c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
    COND_FOVY,
    COND_HEIGHT,
    COND_WIDTH,
)

# Paths of the files written by process_model_run() and create_zip_file().
# Nothing visible here reads the list back.
generated_files = []
|
|
| |
# Remove whatever an earlier run left behind in Gradio's temp directory.
_stale_dir = os.environ["GRADIO_TEMP_DIR"]
if os.path.exists(_stale_dir):
    print(f"Deleting {_stale_dir}")
    import shutil

    shutil.rmtree(_stale_dir)
|
|
# Device reported by the spar3d helper; the model and every batch are moved
# onto it.
device = spar3d_utils.get_device()

# Load the pretrained SPAR3D checkpoint once at import time, switch it to
# eval mode and move it to the selected device.
_PRETRAINED_NAME = "stabilityai/stable-point-aware-3d"
model = SPAR3D.from_pretrained(
    _PRETRAINED_NAME,
    config_name="config.yaml",
    weight_name="model.safetensors",
)
model.eval()
model = model.to(device)

# Every entry of the bundled examples folder, in os.listdir() order.
_EXAMPLES_DIR = "demo_files/examples"
example_files = [
    os.path.join(_EXAMPLES_DIR, entry) for entry in os.listdir(_EXAMPLES_DIR)
]
|
|
def create_zip_file(glb_file, pc_file, illumination_file):
    """Bundle the three generated files into a single ZIP archive.

    The archive is written as ``spar3d_output.zip`` inside a freshly created
    temporary directory. Inside it the inputs are stored under the fixed names
    ``mesh.glb``, ``points.ply`` and ``illumination.hdr``. The archive path is
    also appended to the module-level ``generated_files`` list.

    Args:
        glb_file: Path of the mesh file.
        pc_file: Path of the point-cloud file.
        illumination_file: Path of the illumination map file.

    Returns:
        The path of the archive, or ``None`` when any of the three paths is
        falsy (``None``, empty string, ...).
    """
    members = (
        (glb_file, "mesh.glb"),
        (pc_file, "points.ply"),
        (illumination_file, "illumination.hdr"),
    )
    # Nothing to bundle unless all three inputs are present.
    if any(not source for source, _ in members):
        return None

    archive_path = os.path.join(tempfile.mkdtemp(), "spar3d_output.zip")
    with zipfile.ZipFile(archive_path, "w") as archive:
        for source, arcname in members:
            archive.write(source, arcname)

    generated_files.append(archive_path)
    return archive_path
|
|
def forward_model(
    batch,
    system,
    guidance_scale=3.0,
    seed=0,
    device="cuda",
    remesh_option="none",
    vertex_count=-1,
    texture_resolution=1024,
):
    """Produce a mesh, a point cloud and an illumination map for one batch.

    When ``batch`` has no ``"pc_cond"`` entry, a point cloud is sampled with
    ``system.sampler`` (conditioned on ``system.forward_pdiff_cond(batch)``),
    bbox-normalised and stored in ``batch["pc_cond"]``. The points are then
    randomly subsampled to at most 512 and the batch is handed to
    ``system.generate_mesh``.

    ``batch`` is modified in place (``"pc_cond"`` is added or replaced).

    Args:
        batch: Dict of conditioning tensors; must contain ``"rgb_cond"``.
        system: Model object providing ``forward_pdiff_cond``, ``sampler`` and
            ``generate_mesh``.
        guidance_scale: Forwarded to the point sampler.
        seed: Seed applied to the ``random``, ``torch`` and ``numpy`` RNGs.
        device: Device name forwarded to the point sampler.
        remesh_option: Forwarded to ``generate_mesh`` as ``remesh``.
        vertex_count: Forwarded to ``generate_mesh``.
        texture_resolution: Forwarded to ``generate_mesh``.

    Returns:
        A tuple ``(mesh, point_cloud, illumination)``: the first mesh returned
        by ``generate_mesh``, a ``trimesh.PointCloud`` built from the first
        batch element of ``batch["pc_cond"]``, and the first entry of the
        estimated illumination as a NumPy array.

    Raises:
        RuntimeError: If the point sampler yields no step at all.
    """
    batch_size = batch["rgb_cond"].shape[0]

    # Seed every RNG used below from the caller-supplied seed.
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    cond_tokens = system.forward_pdiff_cond(batch)

    if "pc_cond" not in batch:
        sample_iter = system.sampler.sample_batch_progressive(
            batch_size,
            cond_tokens,
            guidance_scale=guidance_scale,
            device=device,
        )
        # Only the output of the last sampling step is kept.
        samples = None
        for step in sample_iter:
            samples = step["xstart"]
        if samples is None:
            raise RuntimeError("Point sampler did not yield any samples")
        batch["pc_cond"] = samples.permute(0, 2, 1).float()
        batch["pc_cond"] = spar3d_utils.normalize_pc_bbox(batch["pc_cond"])

    # Randomly keep at most 512 points along dimension 1.
    keep = torch.randperm(batch["pc_cond"].shape[1])[:512]
    batch["pc_cond"] = batch["pc_cond"][:, keep]

    # Point cloud of the first batch element: columns 0-2 are used as
    # positions, columns 3-5 (scaled by 255) as uint8 colours.
    xyz = batch["pc_cond"][0, :, :3].cpu().numpy()
    color_rgb = (batch["pc_cond"][0, :, 3:6] * 255).cpu().numpy().astype(np.uint8)
    pc_rgb_trimesh = trimesh.PointCloud(vertices=xyz, colors=color_rgb)

    # Use the `system` passed in by the caller instead of reaching for the
    # module-level `model`, so the function honours its own argument.
    meshes, extras = system.generate_mesh(
        batch,
        texture_resolution,
        remesh=remesh_option,
        vertex_count=vertex_count,
        estimate_illumination=True,
    )
    trimesh_mesh = meshes[0]
    illumination = extras["illumination"]

    return trimesh_mesh, pc_rgb_trimesh, illumination.cpu().detach().numpy()[0]
|
|
def process_model_run(
    fr_res,
    guidance_scale,
    random_seed,
    pc_cond,
    remesh_option,
    vertex_count_type,
    vertex_count,
    texture_resolution,
):
    """Run the model on a prepared image and write the results to disk.

    The image is turned into a batch with ``create_batch``, moved to the
    module-level ``device`` and passed to ``forward_model`` under
    ``torch.no_grad()`` (plus bfloat16 autocast when the device name contains
    ``"cuda"``). The mesh, point cloud and illumination map are then exported
    into a fresh temporary directory and their paths are appended to
    ``generated_files``.

    Args:
        fr_res: Image handed to ``create_batch``.
        guidance_scale: Forwarded to ``forward_model``.
        random_seed: Forwarded to ``forward_model`` as ``seed``.
        pc_cond: Accepted for interface compatibility; not used here.
        remesh_option: Lower-cased and forwarded to ``forward_model``.
        vertex_count_type: Accepted for interface compatibility; not used here.
        vertex_count: Forwarded to ``forward_model``.
        texture_resolution: Forwarded to ``forward_model``.

    Returns:
        A tuple ``(glb_path, ply_path, hdr_path, point_cloud)`` where the first
        three are the exported file paths and the last is the point cloud
        object returned by ``forward_model``.
    """
    start = time.time()

    # Mixed precision is only enabled on CUDA devices.
    autocast_ctx = (
        torch.autocast(device_type=device, dtype=torch.bfloat16)
        if "cuda" in device
        else nullcontext()
    )
    with torch.no_grad(), autocast_ctx:
        model_batch = create_batch(fr_res)
        model_batch = {k: v.to(device) for k, v in model_batch.items()}

        trimesh_mesh, trimesh_pc, illumination_map = forward_model(
            model_batch,
            model,
            guidance_scale=guidance_scale,
            seed=random_seed,
            # Sample on the same device the batch and model live on rather
            # than a hard-coded "cuda", which breaks on hosts without CUDA.
            device=device,
            remesh_option=remesh_option.lower(),
            vertex_count=vertex_count,
            texture_resolution=texture_resolution,
        )

    # All three outputs of one run share a fresh temporary directory.
    temp_dir = tempfile.mkdtemp()

    tmp_file = os.path.join(temp_dir, "mesh.glb")
    trimesh_mesh.export(tmp_file, file_type="glb", include_normals=True)
    generated_files.append(tmp_file)

    tmp_file_pc = os.path.join(temp_dir, "points.ply")
    trimesh_pc.export(tmp_file_pc)
    generated_files.append(tmp_file_pc)

    tmp_file_illumination = os.path.join(temp_dir, "illumination.hdr")
    cv2.imwrite(tmp_file_illumination, illumination_map)
    generated_files.append(tmp_file_illumination)

    print("Generation took:", time.time() - start, "s")

    return tmp_file, tmp_file_pc, tmp_file_illumination, trimesh_pc
|
|
def create_batch(input_image: Image.Image) -> dict[str, Any]:
    """Build a single-element model batch from a conditioning image.

    The image is converted to RGBA, resized to ``COND_WIDTH`` x
    ``COND_HEIGHT`` and scaled to ``[0, 1]``. Its alpha channel becomes the
    mask, and the colour channels are blended with ``BACKGROUND_COLOR``
    according to that mask. The module-level camera pose and intrinsics are
    added, and every entry receives a leading batch dimension of size 1.

    Args:
        input_image: Source image. Images without an alpha channel are
            treated as fully opaque.

    Returns:
        Dict with the keys ``rgb_cond``, ``mask_cond``, ``c2w_cond``,
        ``intrinsic_cond`` and ``intrinsic_normed_cond``.
    """
    # Guarantee a 4-channel image so that the last channel really is alpha;
    # without this an RGB input would use its blue channel as the mask and a
    # single-channel input would fail on the 3-D indexing below.
    rgba = input_image.convert("RGBA").resize((COND_WIDTH, COND_HEIGHT))
    pixels = np.asarray(rgba).astype(np.float32) / 255.0
    img_cond = torch.from_numpy(pixels).clip(0, 1)

    mask_cond = img_cond[:, :, -1:]
    # lerp(background, colour, mask): background where mask is 0, image
    # colour where mask is 1.
    background = torch.tensor(BACKGROUND_COLOR)[None, None, :]
    rgb_cond = torch.lerp(background, img_cond[:, :, :3], mask_cond)

    batch_elem = {
        "rgb_cond": rgb_cond,
        "mask_cond": mask_cond,
        "c2w_cond": c2w_cond.unsqueeze(0),
        "intrinsic_cond": intrinsic.unsqueeze(0),
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
    }
    # Add the leading batch dimension to every entry.
    return {key: value.unsqueeze(0) for key, value in batch_elem.items()}
|
|
def remove_background(input_image: Image.Image) -> Image.Image:
    """Run the shared background remover on an image.

    The image is converted to RGB before being handed to
    ``bg_remover.process``; whatever that call returns is passed through
    unchanged (presumably an image with the background made transparent --
    TODO confirm against the remover's default output type).
    """
    rgb_image = input_image.convert("RGB")
    return bg_remover.process(rgb_image)
|
|
def auto_process(input_image):
    """Run the whole pipeline on an uploaded image with fixed settings.

    Steps: remove the background, crop around the foreground, run the model,
    and bundle the resulting files into a ZIP archive.

    Args:
        input_image: The uploaded image, or ``None`` when nothing is present.

    Returns:
        A tuple ``(glb_file, illumination_file, zip_file, pc_list)``; all four
        entries are ``None`` when ``input_image`` is ``None``.
    """
    if input_image is None:
        return None, None, None, None

    # Cut out the subject and crop around it at the conditioning resolution.
    cutout = remove_background(input_image)
    fr_res = spar3d_utils.foreground_crop(
        cutout,
        crop_ratio=1.3,
        newsize=(COND_WIDTH, COND_HEIGHT),
        no_crop=False,
    )

    # The demo exposes no controls, so every generation setting is fixed here.
    glb_file, pc_file, illumination_file, pc_list = process_model_run(
        fr_res=fr_res,
        guidance_scale=3.0,
        random_seed=0,
        pc_cond=None,
        remesh_option="None",
        vertex_count_type="Keep Vertex Count",
        vertex_count=2000,
        texture_resolution=1024,
    )

    zip_file = create_zip_file(glb_file, pc_file, illumination_file)
    return glb_file, illumination_file, zip_file, pc_list
|
|
| |
# Gradio UI: an image input on one side, the 3D viewer and the ZIP download on
# the other. Uploading an image triggers auto_process().
with gr.Blocks() as demo:
    # The Markdown text is kept flush-left so the string content is unchanged.
    gr.Markdown(
        """
# SPAR3D: Stable Point-Aware Reconstruction of 3D Objects from Single Images
Upload an image to generate a 3D model.
"""
    )

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(
                type="pil",
                label="Upload Image",
                # "click" is not a source Gradio accepts; the valid names are
                # "upload", "webcam" and "clipboard".
                sources=["upload", "clipboard"],
                image_mode="RGBA",
            )

        with gr.Column():
            output_3d = LitModel3D(
                label="3D Model",
                clear_color=[0.0, 0.0, 0.0, 0.0],
                tonemapping="aces",
                contrast=1.0,
                scale=1.0,
            )
            download_all_btn = gr.File(
                label="Download Model (ZIP)",
                file_count="single",
                visible=True,
            )

    # auto_process() returns four values, but only the mesh and the ZIP have a
    # visible component; the illumination path and the point cloud go into
    # throw-away State holders.
    illumination_state = gr.State()
    point_cloud_state = gr.State()
    input_img.upload(
        auto_process,
        inputs=[input_img],
        outputs=[
            output_3d,
            illumination_state,
            download_all_btn,
            point_cloud_state,
        ],
    )

demo.queue().launch(share=False)