2024-01-23 PySpark Distributed Prediction with a TF 2.0 Model

Code from GitHub.

Code

Load Dependencies

import tensorflow as tf
from pyspark import SparkFiles
from pyspark.sql.functions import udf
import pyspark.sql.types as T
from pyspark.sql import Row
print(tf.__version__)

Fetch SavedModel from S3/GCS and Distribute to Nodes

S3_PREFIX = "s3://"

MODEL_BUCKET = "my-models-bucket"
MODEL_PATH = "path/to/my/model/dir"
MODEL_NAME = "model"

S3_MODEL = f"{S3_PREFIX}{MODEL_BUCKET}/{MODEL_PATH}/{MODEL_NAME}"

print("Fetching model", S3_MODEL)

# Add model to all workers
spark.sparkContext.addFile(S3_MODEL, recursive=True)
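
This step assumes a SavedModel already exists at S3_MODEL. For completeness, a minimal sketch of how such a model could be exported before uploading; the local path and the upload command are assumptions, not part of the original post:

# Run offline, outside the Spark job: export the Keras model in SavedModel format
tf.saved_model.save(model, "/tmp/model")
# Then upload the directory, e.g. with the AWS CLI (hypothetical command):
#   aws s3 cp --recursive /tmp/model s3://my-models-bucket/path/to/my/model/dir/model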

Create the Input DataFrame

# In this example, the SavedModel has the following format:

# inputs = tf.keras.Input(shape=(784,), name='img')
# x = layers.Dense(64, activation='relu')(inputs)
# x = layers.Dense(64, activation='relu')(x)
# outputs = layers.Dense(10, activation='softmax')(x)
# model = tf.keras.Model(inputs=inputs, outputs=outputs, name='mnist_model')

(_, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
x_test = x_test.reshape(10000, 784).astype('float32') / 255

# Wrap each image in a singleton batch, so each row's `img` is a 1 x 784 nested array
rows = list(map(lambda n: Row(img=[n.tolist()]), x_test))

schema = T.StructType([T.StructField('img', T.ArrayType(T.ArrayType(T.FloatType())))])

input_df = spark.createDataFrame(rows, schema=schema)
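
A quick sanity check that the DataFrame matches what the UDF will expect (expected output shown as comments):

input_df.printSchema()
# root
#  |-- img: array (nullable = true)
#  |    |-- element: array (containsNull = true)
#  |    |    |-- element: float (containsNull = true)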

Memoize Retrieval of the Saved Model

# Simple memoization helper with a single cache key
def compute_once(f):
    K = '0'
    cache = {}
    
    def wrapper(x):
        # Set on first call
        if K not in cache:
            cache[K] = f(x)
        
        return cache[K]

    return wrapper
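
Because the cache uses a single fixed key, the wrapped function runs at most once per Python process, and every later call returns the first result regardless of its argument. A tiny illustration (the names here are only for demonstration):

def slow_load(x):
    print("loading...")
    return x * 2

load_once = compute_once(slow_load)
print(load_once(2))    # prints "loading..." then 4
print(load_once(100))  # prints 4 again; slow_load is not called a second time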
    

def load_model(model_name):
    # Models are saved under the SparkFiles root directory
    root_dir = SparkFiles.getRootDirectory()
    export_dir = f"{root_dir}/{model_name}"
    
    return tf.saved_model.load(export_dir, tags=['serve'])
    

# Only load the model once per worker!
# The reduced disk IO makes prediction much faster
memo_model_load = compute_once(load_model)

def get_model_prediction(model_name, model_input):
    """
    Note:
        The TF session is scoped to where the model is loaded.
        All calls to the model's ConcreteFunction must be made in the same
        scope as the loaded model (i.e. in the same function!).

        Otherwise TF will throw errors about undefined variables.
    """
    # Load the model (from disk on the first call, from cache afterwards)
    m = memo_model_load(model_name)

    # Look up the default serving signature
    pred_func = m.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

    return pred_func(model_input)
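
If the signature key or output names of a SavedModel are unclear, they can be inspected directly. A small diagnostic sketch (run wherever the model files are available):

m = memo_model_load(MODEL_NAME)
# DEFAULT_SERVING_SIGNATURE_DEF_KEY is the string 'serving_default'
print(list(m.signatures.keys()))
pred_func = m.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
print(pred_func.structured_outputs)  # maps output names to tensors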

Create the Predict UDF

# Decorator with return type of UDF
@udf("array<array<float>>")
def infer(data):
    # Cast the input to a Tensor
    input_data = tf.constant(data)
    
    # Returns a dict of the form { TENSOR_NAME: Tensor }
    outputs = get_model_prediction(MODEL_NAME, input_data)

    # Assuming we have a single output
    output_tensor = list(outputs.values())[0]
    
    # Convert back to regular python
    output_value = output_tensor.numpy().tolist()
    
    return output_value

Infer on the Dataset 🎉

# Using mapPartitions here is actually recommended instead; it is faster (see the sketch below)
predictions_df = input_df.withColumn("predictions", infer("img"))

# All done :) 
predictions_df.show(vertical=True)
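
As the comment above suggests, mapPartitions can be faster than a per-row UDF: the model lookup and signature resolution happen once per partition rather than once per row. A minimal sketch of that approach; the helper name predict_partition is my own, and letting createDataFrame infer the output schema is an assumption:

def predict_partition(rows):
    # Resolve the model and its serving signature once for the whole partition
    m = memo_model_load(MODEL_NAME)
    pred_func = m.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

    for row in rows:
        outputs = pred_func(tf.constant(row.img))
        preds = list(outputs.values())[0].numpy().tolist()
        yield Row(img=row.img, predictions=preds)

predictions_df = spark.createDataFrame(input_df.rdd.mapPartitions(predict_partition))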