| import argparse |
| import os |
| from contextlib import nullcontext |
|
|
| import torch |
| from PIL import Image |
| from tqdm import tqdm |
| from transparent_background import Remover |
|
|
| from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE |
| from spar3d.system import SPAR3D |
| from spar3d.utils import foreground_crop, get_device, remove_background |
|
|
|
|
def check_positive(value):
    """Argparse ``type=`` callback: parse *value* as a strictly positive int.

    Returns the parsed integer when it is > 0; raises
    ``argparse.ArgumentTypeError`` otherwise. A non-numeric string raises
    ``ValueError`` from ``int()``, which argparse also reports as a usage error.
    """
    parsed = int(value)
    if parsed > 0:
        return parsed
    raise argparse.ArgumentTypeError("%s is an invalid positive int value" % value)
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: preprocess input image(s), run SPAR3D inference in
    # batches, and export one mesh + point cloud per input image.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "image", type=str, nargs="+", help="Path to input image(s) or folder."
    )
    parser.add_argument(
        "--device",
        default=get_device(),
        type=str,
        help=f"Device to use. If no CUDA/MPS-compatible device is found, the baking will fail. Default: '{get_device()}'",
    )
    parser.add_argument(
        "--pretrained-model",
        default="stabilityai/stable-point-aware-3d",
        type=str,
        help="Path to the pretrained model. Could be either a huggingface model id or a local path. Default: 'stabilityai/stable-point-aware-3d'",
    )
    parser.add_argument(
        "--foreground-ratio",
        default=1.3,
        type=float,
        # Help text now states the actual default (it previously claimed 0.85).
        help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 1.3",
    )
    parser.add_argument(
        "--output-dir",
        default="output/",
        type=str,
        help="Output directory to save the results. Default: 'output/'",
    )
    parser.add_argument(
        "--texture-resolution",
        default=1024,
        type=int,
        help="Texture atlas resolution. Default: 1024",
    )
    parser.add_argument(
        "--low-vram-mode",
        action="store_true",
        help=(
            "Use low VRAM mode. SPAR3D consumes 10.5GB of VRAM by default. "
            "This mode will reduce the VRAM consumption to roughly 7GB but in exchange "
            "the model will be slower. Default: False"
        ),
    )

    # Remeshing backends are optional extras; only offer the choices that are
    # actually available in this installation.
    remesh_choices = ["none"]
    if TRIANGLE_REMESH_AVAILABLE:
        remesh_choices.append("triangle")
    if QUAD_REMESH_AVAILABLE:
        remesh_choices.append("quad")
    parser.add_argument(
        "--remesh_option",
        choices=remesh_choices,
        default="none",
        help="Remeshing option",
    )
    if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE:
        parser.add_argument(
            "--reduction_count_type",
            choices=["keep", "vertex", "faces"],
            default="keep",
            help="Vertex count type",
        )
        parser.add_argument(
            "--target_count",
            type=check_positive,
            help="Selected target count.",
            default=2000,
        )
    parser.add_argument(
        "--batch_size", default=1, type=int, help="Batch size for inference"
    )
    args = parser.parse_args()

    # Validate the requested device. Strip an optional ordinal suffix so that
    # e.g. "cuda:1" is accepted. (The previous substring test accepted junk
    # like "cu" while rejecting "cuda:0".)
    devices = ["cuda", "mps", "cpu"]
    if args.device.split(":")[0] not in devices:
        raise ValueError("Invalid device. Use cuda, mps or cpu")

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Fall back to CPU when no accelerator is actually usable at runtime.
    device = args.device
    if not (torch.cuda.is_available() or torch.backends.mps.is_available()):
        device = "cpu"

    print("Device used: ", device)

    model = SPAR3D.from_pretrained(
        args.pretrained_model,
        config_name="config.yaml",
        weight_name="model.safetensors",
        low_vram_mode=args.low_vram_mode,
    )
    model.to(device)
    model.eval()

    bg_remover = Remover(device=device)
    images = []
    idx = 0

    def handle_image(image_path, idx):
        """Remove the background, crop to the foreground, save the
        preprocessed input to <output_dir>/<idx>/input.png, and queue the
        image for inference."""
        image = remove_background(
            Image.open(image_path).convert("RGBA"), bg_remover
        )
        image = foreground_crop(image, args.foreground_ratio)
        os.makedirs(os.path.join(output_dir, str(idx)), exist_ok=True)
        image.save(os.path.join(output_dir, str(idx), "input.png"))
        images.append(image)

    for image_path in args.image:
        if os.path.isdir(image_path):
            # A directory argument expands to every image file it contains.
            # Use a distinct loop variable so the outer one is not shadowed.
            contained_paths = [
                os.path.join(image_path, f)
                for f in os.listdir(image_path)
                if f.endswith((".png", ".jpg", ".jpeg"))
            ]
            for contained_path in contained_paths:
                handle_image(contained_path, idx)
                idx += 1
        else:
            handle_image(image_path, idx)
            idx += 1

    # Resolve the target vertex count. The reduction arguments only exist
    # when a remesh backend is available (see the conditional add_argument
    # above), so read them defensively with the same defaults.
    reduction_count_type = getattr(args, "reduction_count_type", "keep")
    target_count = getattr(args, "target_count", 2000)
    if reduction_count_type == "keep":
        vertex_count = -1  # sentinel: keep the original vertex count
    elif reduction_count_type == "vertex":
        vertex_count = target_count
    else:
        # "faces": approximately two faces per vertex, so halve the target.
        vertex_count = target_count // 2

    for i in tqdm(range(0, len(images), args.batch_size)):
        image = images[i : i + args.batch_size]
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        with torch.no_grad():
            # autocast's device_type must be the bare backend name ("cuda"),
            # even when the user passed an ordinal like "cuda:0".
            with (
                torch.autocast(device_type="cuda", dtype=torch.bfloat16)
                if "cuda" in device
                else nullcontext()
            ):
                mesh, glob_dict = model.run_image(
                    image,
                    bake_resolution=args.texture_resolution,
                    remesh=args.remesh_option,
                    vertex_count=vertex_count,
                    return_points=True,
                )
        if torch.cuda.is_available():
            print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
        elif torch.backends.mps.is_available():
            print(
                "Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB"
            )

        # Export one mesh + point cloud per input image; output folders are
        # indexed to match the preprocessing step above.
        if len(image) == 1:
            out_mesh_path = os.path.join(output_dir, str(i), "mesh.glb")
            mesh.export(out_mesh_path, include_normals=True)
            out_points_path = os.path.join(output_dir, str(i), "points.ply")
            glob_dict["point_clouds"][0].export(out_points_path)
        else:
            for j in range(len(mesh)):
                out_mesh_path = os.path.join(output_dir, str(i + j), "mesh.glb")
                mesh[j].export(out_mesh_path, include_normals=True)
                out_points_path = os.path.join(output_dir, str(i + j), "points.ply")
                glob_dict["point_clouds"][j].export(out_points_path)
|