找回密码
 立即注册

扫一扫,访问微社区

阿拉丁灯神 该用户已被删除
发表于 2019-1-23 23:36:34
2540
from PIL import Image
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit
import tensorrt as trt
# Module-level TensorRT logger, passed to trt.Builder in build_engine;
# created with WARNING as its minimum severity.
trt_logger = trt.Logger(trt.Logger.WARNING)
def build_engine(model_file, input_name="input_1", input_shape=(1, 28, 28),
                 output_name="dense_1/Softmax", fp16=True,
                 workspace_size=1 << 30):
    """Parse a UFF model file and build a TensorRT CUDA engine from it.

    Args:
        model_file: Path to the ``.uff`` model to parse.
        input_name: Name of the network input node to register with the parser.
        input_shape: Shape registered for that input (the default is the
            1x28x28 image used by the original script).
        output_name: Name of the network output node to register.
        fp16: Value assigned to ``builder.fp16_mode``.
        workspace_size: Value assigned to ``builder.max_workspace_size`` (bytes).

    Returns:
        The engine returned by ``builder.build_cuda_engine``.

    Raises:
        RuntimeError: If the UFF file cannot be parsed or the engine build
            returns None.
    """
    # Context managers release the builder, network and parser once the
    # engine has been built; the engine itself is returned to the caller.
    with trt.Builder(trt_logger) as builder, \
            builder.create_network() as network, \
            trt.UffParser() as parser:
        builder.fp16_mode = fp16
        builder.max_workspace_size = workspace_size
        parser.register_input(input_name, input_shape)
        parser.register_output(output_name)
        # parse() reports failure through its boolean return value instead of
        # raising; without this check an empty network reaches the builder.
        if not parser.parse(model_file, network):
            raise RuntimeError("failed to parse UFF model: " + str(model_file))
        engine = builder.build_cuda_engine(network)
    # build_cuda_engine signals failure by returning None; fail here with a
    # clear message instead of an opaque error at the first use of the engine.
    if engine is None:
        raise RuntimeError("failed to build TensorRT engine from: " + str(model_file))
    return engine
def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
    """Run one asynchronous inference pass and return the host output buffer.

    Args:
        context: Execution context created from the engine.
        bindings: Device pointers (as ints) for every engine binding, in
            binding order.
        inputs: ``(host, device)`` pair; the host array is copied to the device.
        outputs: ``(host, device)`` pair; results are copied back into the
            host array.
        stream: PyCUDA stream on which the copies and the execution are queued.
        batch_size: Batch size passed to ``execute_async``. It must not exceed
            the engine's ``max_batch_size``, and the binding buffers must be
            sized for it.

    Returns:
        The host output array (``outputs[0]``), filled after the stream sync.
    """
    host_in, dev_in = inputs[0], inputs[1]
    host_out, dev_out = outputs[0], outputs[1]
    cuda.memcpy_htod_async(dev_in, host_in, stream)
    context.execute_async(batch_size=batch_size, bindings=bindings,
                          stream_handle=stream.handle)
    cuda.memcpy_dtoh_async(host_out, dev_out, stream)
    # The copies and the execution are only queued on the stream; wait for
    # them to finish before the caller reads the host buffer.
    stream.synchronize()
    return host_out
def _allocate_buffers(engine):
    """Allocate a page-locked host buffer and a device buffer for every binding.

    Returns ``(inputs, outputs, bindings)``. ``inputs`` and ``outputs`` are
    ``(host, device)`` pairs for the last input and the last output binding
    seen. ``bindings`` lists every device pointer (as int) in binding order.

    Raises:
        RuntimeError: If the engine has no input binding or no output binding.
    """
    inputs = outputs = None
    bindings = []
    for binding in engine:
        # Size each buffer for the engine's maximum batch.
        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        host = cuda.pagelocked_empty(size, dtype)
        device = cuda.mem_alloc(host.nbytes)
        bindings.append(int(device))
        if engine.binding_is_input(binding):
            inputs = (host, device)
        else:
            outputs = (host, device)
    if inputs is None or outputs is None:
        raise RuntimeError("engine must have at least one input and one output binding")
    return inputs, outputs, bindings


def main():
    """Build the engine from lenet.uff, run it on 6.pgm and print the argmax."""
    test_num = 6
    data_path = str(test_num) + ".pgm"
    model = "lenet.uff"
    engine = build_engine(model)
    if engine is None:
        # build_cuda_engine returns None on failure; iterating None below
        # would otherwise fail with an unhelpful TypeError.
        raise RuntimeError("failed to build TensorRT engine from " + model)
    stream = cuda.Stream()
    inputs, outputs, bindings = _allocate_buffers(engine)
    context = engine.create_execution_context()
    # Close the image file as soon as its pixels have been copied out.
    with Image.open(data_path) as im:
        img = np.array(im).ravel()
    # Map each pixel to 1 - x/255: scale to [0, 1] and invert. This assumes
    # 8-bit grayscale input. The inversion presumably converts dark-on-light
    # digits to the light-on-dark form the model expects; confirm.
    np.copyto(inputs[0], 1.0 - img / 255)
    result = do_inference(context, bindings, inputs, outputs, stream)
    prediction = np.argmax(result)
    print("test_num: " + str(test_num))
    print("prediction: " + str(prediction))
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()

使用道具 举报 回复
发新帖
您需要登录后才可以回帖 登录 | 立即注册

zzczczxczxczx