訂閱
糾錯
加入自媒體

SageMaker TensorFlow對象檢測模型

這篇文章描述了如何在Amazon SageMaker中使用TensorFlow對象檢測模型API來實現(xiàn)這一點。

首先,基于AWS示例筆記本,將解釋如何使用SageMaker端點在單個圖像上運行模型。對于較小的圖像,這種方法可行,但對于較大的圖像,我們會遇到問題。

為了解決這些問題,改用批處理轉(zhuǎn)換作業(yè)。

起點:使用SageMaker TensorFLow對象檢測API進行模型推斷

AWS提供了一些關(guān)于GitHub如何使用SageMaker的好例子。

使用此示例使用TensorFlow對象檢測API對對象檢測模型進行預(yù)測:

image.png

將模型部署為端點時,可以通過調(diào)用端點,使用該模型一次推斷一個圖像。此代碼取自示例筆記本,顯示了如何定義TensorFlowModel并將其部署為模型端點:

import cv2

import sagemaker

from sagemaker.utils import name_from_base

from sagemaker.tensorflow import TensorFlowModel

role = sagemaker.get_execution_role()

model_artefact = '<your-model-s3-path>'

model_endpoint = TensorFlowModel(

   name=name_from_base('tf2-object-detection'),

   model_data=model_artefact,

   role=role,

   framework_version='2.2',



predictor = model_endpoint.deploy(initial_instance_count=1, instance_type='ml.m5.large')

然后,將圖像加載為NumPy數(shù)組,并將其解析為列表,以便將其傳遞給端點:

def image_file_to_tensor(path):

   cv_img = cv2.imread(path,1).a(chǎn)stype('uint8')

   cv_img = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)

   return cv_img

img = image_file_to_tensor('test_images/22673445.jpg')


input = {

 'instances': [img.tolist()]

最后,調(diào)用端點:

detections = predictor.predict(input)['predictions'][0]

問題:端點請求負載大小太大

這在使用小圖像時很好,因為API調(diào)用的請求負載足夠小。然而,當使用較大的圖片時,API返回413錯誤。這意味著有效負載超過了允許的大小,即6 MB。

當然,我們可以在調(diào)用端點之前調(diào)整圖像的大小,但我想使用批處理轉(zhuǎn)換作業(yè)。

解決方案:改用批處理轉(zhuǎn)換作業(yè)

使用SageMaker批量轉(zhuǎn)換作業(yè),你可以定義自己的最大負載大小,這樣我們就不會遇到413個錯誤。其次,這些作業(yè)可用于一次性處理全套圖像。

圖像需要存儲在S3存儲桶中。所有圖像都以批處理模式(名稱中的內(nèi)容)進行處理,預(yù)測也存儲在S3上。

為了使用批處理轉(zhuǎn)換作業(yè),我們再次定義了TensorFlowModel,但這次我們還定義了入口點和源目錄:

model_batch = TensorFlowModel(

   name=name_from_base('tf2-object-detection'),

   model_data=model_artifact,

   role=role,

   framework_version='2.2',

   entry_point='inference.py',

   source_dir='.',

inference.py代碼轉(zhuǎn)換模型的輸入和輸出數(shù)據(jù),如文檔中所述。此代碼需要將請求負載(圖像)更改為NumPy數(shù)組,并將其解析為列表對象。

從這個示例開始,我更改了代碼,使其加載圖像并將其轉(zhuǎn)換為NumPy數(shù)組。inference.py中input_handler函數(shù)更改為以下內(nèi)容:

import io

import json

import numpy as np

from PIL import Image

def input_handler(data, context):

   """ Pre-process request input before it is sent to TensorFlow Serving REST API

   Args:

       data (obj): the request data, in format of dict or string

       context (Context): an object containing request and configuration details

   Returns:

      (dict): a JSON-serializable dict that contains request body and headers

  """


   if context.request_content_type == "application/x-image":

       payload = data.read()              
       image = Image.open(io.BytesIO(payload))        
       array = np.a(chǎn)sarray(image)
       return json.dumps({'instances': [array.tolist()]})
   raise ValueError('{{"error": "unsupported content type {}"}}'.format(
       context.request_content_type or "unknown"))

注意,在上面的代碼中排除了output_handler函數(shù)。

此函數(shù)需要Python包NumPy和Pillow,它們未安裝在運行批處理推斷作業(yè)的機器上。

我們可以創(chuàng)建自己的鏡像并使用該鏡像(在TensorFlowModel對象初始化時使用image_uri關(guān)鍵字)。

也可以提供requirements.txt并將其存儲在筆記本所在的文件夾中(稱為source_dir=“.”)。該文件在鏡像引導(dǎo)期間用于使用pip安裝所需的包。內(nèi)容為:

numpy

pillow

首先,想使用OpenCV(就像在endpoint示例中一樣),但該軟件包不太容易安裝。

我們現(xiàn)在使用模型創(chuàng)建transformer對象,而不是將模型部署為模型端點:

input_path = "s3://bucket/input"

output_path = "s3://bucket/output"


tensorflow_serving_transformer = model_batch.transformer(

   instance_count=1,

   instance_type="ml.m5.large",

   max_concurrent_transforms=1,

   max_payload=5,

   output_path=output_path,

最后,使用transform:

tensorflow_serving_transformer.transform(

   input_path,

   content_type="application/x-image",

圖像由模型處理,結(jié)果將作為JSON文件最終在output_path bucket中。命名等于輸入文件名,后跟.out擴展名。你還可以調(diào)整和優(yōu)化實例類型、最大負載等。

最后

這很可能不是最具成本效益的方法,因為我們將圖像作為NumPy數(shù)組傳遞給轉(zhuǎn)換器。

此外,我們還可以在inference.py中調(diào)整output_handler函數(shù)壓縮并存儲在S3上的JSON,或僅返回相關(guān)檢測。

       原文標題 : SageMaker TensorFlow對象檢測模型

聲明: 本文由入駐維科號的作者撰寫,觀點僅代表作者本人,不代表OFweek立場。如有侵權(quán)或其他問題,請聯(lián)系舉報。

發(fā)表評論

0條評論,0人參與

請輸入評論內(nèi)容...

請輸入評論/評論長度6~500個字

您提交的評論過于頻繁,請輸入驗證碼繼續(xù)

暫無評論

暫無評論

人工智能 獵頭職位 更多
掃碼關(guān)注公眾號
OFweek人工智能網(wǎng)
獲取更多精彩內(nèi)容
文章糾錯
x
*文字標題:
*糾錯內(nèi)容:
聯(lián)系郵箱:
*驗 證 碼:

粵公網(wǎng)安備 44030502002758號