mirror of
https://github.com/microsoft/TRELLIS.2.git
synced 2026-04-02 02:27:08 -04:00
152 lines
5.1 KiB
Python
152 lines
5.1 KiB
Python
import gradio as gr
|
|
|
|
import os
|
|
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
|
from datetime import datetime
|
|
import shutil
|
|
from typing import *
|
|
import torch
|
|
import numpy as np
|
|
import trimesh
|
|
from PIL import Image
|
|
from trellis2.pipelines import Trellis2TexturingPipeline
|
|
|
|
|
|
# Largest value representable as a 32-bit signed integer; used as the upper
# bound of the seed slider and of randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
# Root folder for per-session output directories: a `tmp` folder located next
# to this script.
TMP_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    'tmp',
)
|
|
|
|
|
|
def start_session(req: gr.Request):
    """
    Create the output directory for a session.

    Args:
        req (gr.Request): The request; its ``session_hash`` names the
            directory created under ``TMP_DIR``.
    """
    # exist_ok=True makes this a no-op when the directory is already there.
    os.makedirs(os.path.join(TMP_DIR, str(req.session_hash)), exist_ok=True)
|
|
|
|
|
|
def end_session(req: gr.Request):
    """
    Remove the output directory of a session.

    Args:
        req (gr.Request): The request; its ``session_hash`` names the
            directory under ``TMP_DIR`` to delete.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    try:
        shutil.rmtree(user_dir)
    except FileNotFoundError:
        # Nothing to clean up: the directory was never created or has
        # already been removed. Other errors still propagate.
        pass
|
|
|
|
|
|
def preprocess_image(image: Image.Image) -> Image.Image:
    """
    Run the pipeline's image preprocessing on an image.

    Args:
        image (Image.Image): The image to preprocess.

    Returns:
        Image.Image: The result of ``pipeline.preprocess_image`` for ``image``.
    """
    # `pipeline` is the module-level pipeline object, looked up at call time.
    return pipeline.preprocess_image(image)
|
|
|
|
|
|
def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Choose the seed for a run.

    Args:
        randomize_seed (bool): Whether to draw a fresh random seed.
        seed (int): The seed to use when no random seed is requested.

    Returns:
        int: ``seed`` unchanged, or a random integer in ``[0, MAX_SEED)``
        when ``randomize_seed`` is truthy.
    """
    if not randomize_seed:
        return seed
    return np.random.randint(0, MAX_SEED)
|
|
|
|
|
|
def shapeimage_to_tex(
    mesh_file: str,
    image: Image.Image,
    seed: int,
    resolution: str,
    texture_size: int,
    tex_slat_guidance_strength: float,
    tex_slat_guidance_rescale: float,
    tex_slat_sampling_steps: int,
    tex_slat_rescale_t: float,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str]:
    """
    Texture a mesh from a reference image and export the result as a GLB file.

    Args:
        mesh_file (str): Path of the mesh file, loaded with ``trimesh.load``.
        image (Image.Image): The reference image. It is passed to the pipeline
            with ``preprocess_image=False``, i.e. no preprocessing is requested here.
        seed (int): The seed forwarded to the pipeline.
        resolution (str): The resolution as a string; converted with ``int()``.
        texture_size (int): The texture size forwarded to the pipeline.
        tex_slat_guidance_strength (float): Sampler ``guidance_strength``.
        tex_slat_guidance_rescale (float): Sampler ``guidance_rescale``.
        tex_slat_sampling_steps (int): Sampler ``steps``.
        tex_slat_rescale_t (float): Sampler ``rescale_t``.
        req (gr.Request): The request; its ``session_hash`` selects the output
            directory under ``TMP_DIR``.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        Tuple[str, str]: The path of the exported GLB file, twice (one value
        for each of the two outputs this handler feeds).

    Raises:
        gr.Error: If no mesh file or no image was provided.
    """
    # Fail early with a readable message instead of an opaque loader error.
    if mesh_file is None:
        raise gr.Error("Please upload a mesh before generating.")
    if image is None:
        raise gr.Error("Please upload a reference image before generating.")

    mesh = trimesh.load(mesh_file)
    if isinstance(mesh, trimesh.Scene):
        # The pipeline is given a single mesh, so collapse a scene into one.
        mesh = mesh.to_mesh()

    try:
        output = pipeline.run(
            mesh,
            image,
            seed=seed,
            preprocess_image=False,
            tex_slat_sampler_params={
                "steps": tex_slat_sampling_steps,
                "guidance_strength": tex_slat_guidance_strength,
                "guidance_rescale": tex_slat_guidance_rescale,
                "rescale_t": tex_slat_rescale_t,
            },
            resolution=int(resolution),
            texture_size=texture_size,
        )

        # Millisecond-resolution timestamp so each run gets its own file name.
        now = datetime.now()
        timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
        user_dir = os.path.join(TMP_DIR, str(req.session_hash))
        os.makedirs(user_dir, exist_ok=True)
        glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
        output.export(glb_path, extension_webp=True)
    finally:
        # Release cached GPU memory whether or not the run succeeded.
        torch.cuda.empty_cache()

    return glb_path, glb_path
|
|
|
|
|
|
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    ## Texturing a mesh with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
    * Upload a mesh and corresponding reference image (preferably with an alpha-masked foreground object) and click Generate to create a textured 3D asset.
    """)

    with gr.Row():
        # Narrow column: inputs and generation settings.
        with gr.Column(scale=1, min_width=360):
            mesh_file = gr.File(
                label="Upload Mesh",
                file_types=[".ply", ".obj", ".glb", ".gltf"],
                file_count="single",
            )
            image_prompt = gr.Image(
                label="Image Prompt",
                format="png",
                image_mode="RGBA",
                type="pil",
                height=400,
            )

            resolution = gr.Radio(
                choices=["512", "1024", "1536"],
                label="Resolution",
                value="1024",
            )
            seed = gr.Slider(minimum=0, maximum=MAX_SEED, label="Seed", value=0, step=1)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            texture_size = gr.Slider(
                minimum=1024,
                maximum=4096,
                label="Texture Size",
                value=2048,
                step=1024,
            )

            generate_btn = gr.Button("Generate")

            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row():
                    tex_slat_guidance_strength = gr.Slider(
                        minimum=1.0,
                        maximum=10.0,
                        label="Guidance Strength",
                        value=1.0,
                        step=0.1,
                    )
                    tex_slat_guidance_rescale = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        label="Guidance Rescale",
                        value=0.0,
                        step=0.01,
                    )
                    tex_slat_sampling_steps = gr.Slider(
                        minimum=1,
                        maximum=50,
                        label="Sampling Steps",
                        value=12,
                        step=1,
                    )
                    tex_slat_rescale_t = gr.Slider(
                        minimum=1.0,
                        maximum=6.0,
                        label="Rescale T",
                        value=3.0,
                        step=0.1,
                    )

        # Wide column: 3D preview of the result and its download button.
        with gr.Column(scale=10):
            glb_output = gr.Model3D(
                label="Extracted GLB",
                height=724,
                show_label=True,
                display_mode="solid",
                clear_color=(0.25, 0.25, 0.25, 1.0),
            )
            download_btn = gr.DownloadButton(label="Download GLB")

    # Handlers
    # Create the session directory on page load and remove it on unload.
    demo.load(start_session)
    demo.unload(end_session)

    # Replace an uploaded image with its preprocessed version.
    image_prompt.upload(
        fn=preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt],
    )

    # Resolve the seed first (writing it back to the slider), then texture.
    generate_btn.click(
        fn=get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        fn=shapeimage_to_tex,
        inputs=[
            mesh_file,
            image_prompt,
            seed,
            resolution,
            texture_size,
            tex_slat_guidance_strength,
            tex_slat_guidance_rescale,
            tex_slat_sampling_steps,
            tex_slat_rescale_t,
        ],
        outputs=[glb_output, download_btn],
    )
|
|
|
|
|
|
# Launch the Gradio app
|
|
if __name__ == "__main__":
    # Make sure the root for per-session output folders exists before serving.
    os.makedirs(TMP_DIR, exist_ok=True)

    # Bound at module level: the handlers above look up `pipeline` when called.
    pipeline = Trellis2TexturingPipeline.from_pretrained(
        'microsoft/TRELLIS.2-4B',
        config_file="texturing_pipeline.json",
    )
    pipeline.cuda()

    demo.launch()
|