Skip to content

进阶文档‐Fast API调用

wuziheng edited this page Sep 26, 2023 · 8 revisions

1、WebUI启动

要启用 WebUI 自带的 FastAPI 接口,需要在启动命令中加入 --api 参数:

python launch.py --api

其它启动参数按照自己的情况配置即可。加入 --api 后,WebUI 会同时启动自带的 FastAPI 服务。

2、训练 FastAPI 调用

创建 post_train.py,填入下面的代码。代码中的 http://0.0.0.0:7860 需要根据实际情况修改为 WebUI 服务器的 IP 和端口(例如在本机运行时可改为 http://127.0.0.1:7860)。

import base64
import json
import os
import sys
from glob import glob

import cv2
import numpy as np
import requests

def decode_image_from_base64jpeg(base64_image):
    """Decode a base64-encoded image string into an OpenCV image array.

    Args:
        base64_image: base64 text of an encoded image file (presumably JPEG,
            as the name suggests; any format cv2.imdecode supports works).

    Returns:
        numpy.ndarray: 3-channel BGR image, as produced by ``cv2.imdecode``
        with ``cv2.IMREAD_COLOR``.

    Raises:
        ValueError: if the decoded bytes are not a readable image.
    """
    image_bytes = base64.b64decode(base64_image)
    np_arr = np.frombuffer(image_bytes, np.uint8)
    image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imdecode signals failure by returning None; fail here with a
        # clear message instead of letting a later cv2.imwrite choke on it.
        raise ValueError("could not decode base64 payload as an image")
    return image

def post(encoded_images,
         url='http://0.0.0.0:7860/easyphoto/easyphoto_train_forward',
         timeout=1500):
    """Send an EasyPhoto training request and return the raw response text.

    Args:
        encoded_images: list of base64 strings, one per training photo.
        url: training endpoint; change host/port to match your WebUI server.
        timeout: seconds ``requests`` waits for the server before giving up.

    Returns:
        str: the UTF-8 decoded response body (JSON text).

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status; the
            response body is included in the message for debugging.
    """
    payload = {
        # presumably must match the id later passed in the inference
        # request's "user_ids" — confirm against the server
        "user_id"               : "your_user_id",
        "sd_model_checkpoint"   : "Chilloutmix-Ni-pruned-fp16-fix.safetensors",
        "resolution"            : 512,
        "val_and_checkpointing_steps" : 100,
        "max_train_steps"       : 800,
        "steps_per_photos"      : 200,
        "train_batch_size"      : 1,
        "gradient_accumulation_steps" : 4,
        "dataloader_num_workers" : 16,
        "learning_rate"         : 1e-4,
        "rank"                  : 64,
        "network_alpha"         : 64,
        "instance_images"       : encoded_images,
    }
    # json= serializes the payload and sets Content-Type: application/json.
    r = requests.post(url, json=payload, timeout=timeout)
    if not r.ok:
        raise requests.HTTPError(
            f"EasyPhoto train request failed with HTTP {r.status_code}: "
            f"{r.content.decode('utf-8', errors='replace')}",
            response=r,
        )
    return r.content.decode('utf-8')

def list_jpg_images(img_dir):
    """Return the sorted paths of all ``.jpg`` files (any letter case) in *img_dir*.

    The extension is matched case-insensitively in a single pass. Globbing
    ``*.jpg`` and ``*.JPG`` separately returns every file twice on
    case-insensitive platforms such as Windows, which would upload each
    training image twice. Sorting makes the order deterministic.
    """
    return sorted(
        path
        for path in glob(os.path.join(img_dir, "*"))
        if os.path.isfile(path) and os.path.splitext(path)[1].lower() == ".jpg"
    )


if __name__ == '__main__':
    # Usage: python post_train.py your_data_dir
    if len(sys.argv) < 2:
        sys.exit("usage: python post_train.py your_data_dir")
    img_dir = sys.argv[1]
    img_list = list_jpg_images(img_dir)
    if not img_list:
        sys.exit(f"no .jpg images found in {img_dir}")

    # The training endpoint receives every photo inline as base64 text.
    encoded_images = []
    for img_path in img_list:
        with open(img_path, 'rb') as f:
            encoded_images.append(base64.b64encode(f.read()).decode('utf-8'))

    outputs = json.loads(post(encoded_images))
    print(outputs)

然后使用如下 shell 命令调用 API,其中 your_data_dir 是存放训练图片的文件夹路径(脚本只会读取其中 .jpg 格式的图片,如 xxx.jpg、xxx.JPG):

python post_train.py your_data_dir

3、预测 FastAPI 调用

创建 post_infer.py,填入下面的代码。代码中的 http://0.0.0.0:7860 需要根据实际情况修改为 WebUI 服务器的 IP 和端口(例如在本机运行时可改为 http://127.0.0.1:7860)。

import base64
import json
import os
import sys
from glob import glob

import cv2
import numpy as np
import requests

def decode_image_from_base64jpeg(base64_image):
    """Decode a base64-encoded image string into an OpenCV image array.

    Args:
        base64_image: base64 text of an encoded image file (presumably JPEG,
            as the name suggests; any format cv2.imdecode supports works).

    Returns:
        numpy.ndarray: 3-channel BGR image, as produced by ``cv2.imdecode``
        with ``cv2.IMREAD_COLOR``.

    Raises:
        ValueError: if the decoded bytes are not a readable image.
    """
    image_bytes = base64.b64decode(base64_image)
    np_arr = np.frombuffer(image_bytes, np.uint8)
    image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imdecode signals failure by returning None; fail here with a
        # clear message instead of letting a later cv2.imwrite choke on it.
        raise ValueError("could not decode base64 payload as an image")
    return image

def post(encoded_image,
         url='http://0.0.0.0:7860/easyphoto/easyphoto_infer_forward',
         timeout=1500):
    """Send one EasyPhoto inference request and return the raw response text.

    Args:
        encoded_image: base64 text of the template image, sent as
            ``init_image``.
        url: inference endpoint; change host/port to match your WebUI server.
        timeout: seconds ``requests`` waits for the server before giving up.

    Returns:
        str: the UTF-8 decoded response body (JSON text; the caller reads
        its ``outputs`` list).

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status; the
            response body is included in the message for debugging.
    """
    payload = {
        # presumably the user_id(s) used in the training request — confirm
        "user_ids"              : ["your_user_id"],
        "sd_model_checkpoint"   : "Chilloutmix-Ni-pruned-fp16-fix.safetensors",
        "init_image"            : encoded_image,

        "first_diffusion_steps"     : 50,
        "first_denoising_strength"  : 0.45,
        "second_diffusion_steps"    : 20,
        "second_denoising_strength" : 0.35,
        "seed"                      : 12345,
        "crop_face_preprocess"      : True,

        "before_face_fusion_ratio"  : 0.5,
        "after_face_fusion_ratio"   : 0.5,
        "apply_face_fusion_before"  : True,
        "apply_face_fusion_after"   : True,

        "color_shift_middle"        : True,
        "color_shift_last"          : True,
        "super_resolution"          : True,
        "background_restore"        : False,
        "tabs"                      : 1,
    }
    # json= serializes the payload and sets Content-Type: application/json.
    r = requests.post(url, json=payload, timeout=timeout)
    if not r.ok:
        raise requests.HTTPError(
            f"EasyPhoto inference request failed with HTTP {r.status_code}: "
            f"{r.content.decode('utf-8', errors='replace')}",
            response=r,
        )
    return r.content.decode('utf-8')

def list_jpg_images(img_dir):
    """Return the sorted paths of all ``.jpg`` files (any letter case) in *img_dir*.

    The extension is matched case-insensitively in a single pass. Globbing
    ``*.jpg`` and ``*.JPG`` separately returns every file twice on
    case-insensitive platforms such as Windows, so each template would be
    processed twice. Sorting makes the index-based output names stable.
    """
    return sorted(
        path
        for path in glob(os.path.join(img_dir, "*"))
        if os.path.isfile(path) and os.path.splitext(path)[1].lower() == ".jpg"
    )


if __name__ == '__main__':
    # Usage: python post_infer.py your_data_dir
    # Result for the N-th template (in sorted order) is saved as "N.jpg" in
    # the current working directory.
    if len(sys.argv) < 2:
        sys.exit("usage: python post_infer.py your_data_dir")
    img_dir = sys.argv[1]
    img_list = list_jpg_images(img_dir)
    if not img_list:
        sys.exit(f"no .jpg images found in {img_dir}")

    for idx, img_path in enumerate(img_list):
        # Read and close the file before the (potentially long) request.
        with open(img_path, 'rb') as f:
            encoded_image = base64.b64encode(f.read()).decode('utf-8')
        outputs = json.loads(post(encoded_image))
        results = outputs.get("outputs")
        if not results:
            # Report and move on instead of crashing the whole batch.
            print(f"{img_path}: no output image returned: {outputs}")
            continue
        image = decode_image_from_base64jpeg(results[0])
        cv2.imwrite(str(idx) + ".jpg", image)

然后使用如下 shell 命令调用 API,其中 your_data_dir 是存放模板图片的文件夹路径(脚本只会读取其中 .jpg 格式的图片)。每张模板的生成结果会按顺序保存为当前目录下的 0.jpg、1.jpg……:

python post_infer.py your_data_dir

4、SDXL FastAPI 调用

创建 post_infer.py,填入下面的代码。代码中的 http://0.0.0.0:7860 需要根据实际情况修改为 WebUI 服务器的 IP 和端口。注意:该文件名与第 3 节的脚本相同,如需同时保留两个脚本,请另存为其他文件名,并相应修改下方的运行命令。

import base64
import json
import os
import sys
from glob import glob

import cv2
import numpy as np
import requests

def decode_image_from_base64jpeg(base64_image):
    """Decode a base64-encoded image string into an OpenCV image array.

    Args:
        base64_image: base64 text of an encoded image file (presumably JPEG,
            as the name suggests; any format cv2.imdecode supports works).

    Returns:
        numpy.ndarray: 3-channel BGR image, as produced by ``cv2.imdecode``
        with ``cv2.IMREAD_COLOR``.

    Raises:
        ValueError: if the decoded bytes are not a readable image.
    """
    image_bytes = base64.b64decode(base64_image)
    np_arr = np.frombuffer(image_bytes, np.uint8)
    image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imdecode signals failure by returning None; fail here with a
        # clear message instead of letting a later cv2.imwrite choke on it.
        raise ValueError("could not decode base64 payload as an image")
    return image

def post(encoded_image,
         url='http://0.0.0.0:7860/easyphoto/easyphoto_infer_forward',
         timeout=1500):
    """Send one EasyPhoto SDXL inference request and return the raw response text.

    Args:
        encoded_image: base64 text of the template image sent as
            ``init_image``, or ``None`` to generate without a template.
        url: inference endpoint; change host/port to match your WebUI server.
        timeout: seconds ``requests`` waits for the server before giving up.

    Returns:
        str: the UTF-8 decoded response body (JSON text; the caller reads
        its ``outputs`` list).

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status; the
            response body is included in the message for debugging.
    """
    payload = {
        "user_ids"              : ["your_user_id"],
        "sd_model_checkpoint"   : "Chilloutmix-Ni-pruned-fp16-fix.safetensors",
        "init_image"            : encoded_image,

        "first_diffusion_steps"     : 50,
        "first_denoising_strength"  : 0.45,
        "second_diffusion_steps"    : 20,
        "second_denoising_strength" : 0.35,
        # The original dict listed "seed" twice (12345, then -1); the later
        # key silently won, so -1 is kept as the single effective value.
        # NOTE(review): -1 presumably requests a random seed — confirm.
        "seed"                      : -1,
        "crop_face_preprocess"      : True,

        "before_face_fusion_ratio"  : 0.5,
        "after_face_fusion_ratio"   : 0.5,
        "apply_face_fusion_before"  : True,
        "apply_face_fusion_after"   : True,

        "color_shift_middle"        : True,
        "color_shift_last"          : True,
        "super_resolution"          : True,
        "background_restore"        : False,
        "sd_xl_input_prompt"        : "upper-body, look at viewer, one twenty years old girl, wear white shirt, standing, in the garden, daytime, f32",
        "sd_xl_resolution"          : "(1024, 1024)",
        # presumably selects the SDXL pipeline (section 3 sends 1) — confirm
        "tabs"                      : 3,
    }
    # json= serializes the payload and sets Content-Type: application/json.
    r = requests.post(url, json=payload, timeout=timeout)
    if not r.ok:
        raise requests.HTTPError(
            f"EasyPhoto SDXL request failed with HTTP {r.status_code}: "
            f"{r.content.decode('utf-8', errors='replace')}",
            response=r,
        )
    return r.content.decode('utf-8')

if __name__ == '__main__':
    # Usage: python post_infer.py [input_image_path]
    # Without an argument no template is sent ("init_image": None).
    img_path = sys.argv[1] if len(sys.argv) > 1 else None

    encoded_image = None
    if img_path is not None:
        with open(img_path, 'rb') as f:
            encoded_image = base64.b64encode(f.read()).decode('utf-8')

    outputs = json.loads(post(encoded_image))
    results = outputs.get("outputs")
    if not results:
        sys.exit(f"no output image returned: {outputs}")

    # The original decoded the result but never saved it; write it to disk.
    image = decode_image_from_base64jpeg(results[0])
    cv2.imwrite("sdxl_result.jpg", image)
    print("saved result to sdxl_result.jpg")

然后使用如下 shell 命令调用 API,其中 input_image_path 是单张模板图片的路径(不是文件夹):

python post_infer.py input_image_path

也可以不使用任何垫图:将请求中的 "init_image" 设为 None,同样可以直接生成图片。