# NOTE: This script is not fully tested.
"""Run fastText classification over a Parquet dataset with Spark.

Each row's text column is normalized with ``fasttext_preprocess_func`` and
classified with ``fasttext_infer``. Rows predicted as ``__label__pos`` are kept
and written out as Parquet with an extra ``fasttext_pred`` column holding a
JSON string ``{"pred_label": ..., "pred_score": ...}``; all other rows
(including rows whose text is NULL) are dropped.

Usage:
    spark-submit <this script> INPUT_PATH SAVE_PATH [CONTENT_KEY]
"""
import json
import sys
from typing import Optional

from pyspark.sql import SparkSession
from pyspark import SparkConf
from pyspark.sql.functions import col, udf, lit
from pyspark.sql.types import MapType, StringType, FloatType

from preprocess_content import fasttext_preprocess_func
from fasttext_infer import fasttext_infer

# Label returned by the classifier for rows that should be kept.
POSITIVE_LABEL = "__label__pos"
# Column classified when no CONTENT_KEY argument is given on the command line.
DEFAULT_CONTENT_KEY = "content"


def get_fasttext_pred(content: Optional[str]) -> Optional[str]:
    """Classify one text and keep only positive predictions.

    Args:
        content: Raw text of one row, or None when the column value is NULL.

    Returns:
        A JSON string with ``pred_label`` and ``pred_score`` when the predicted
        label is ``__label__pos``; otherwise None (the row is filtered out later).
    """
    # Spark hands SQL NULL to a Python UDF as None; treat it as "not positive"
    # rather than passing None into the preprocessor and failing the task.
    if content is None:
        return None

    norm_content = fasttext_preprocess_func(content)
    label, score = fasttext_infer(norm_content)
    if label != POSITIVE_LABEL:
        return None

    # float() converts numpy scalar types (e.g. float32), which json.dumps
    # cannot serialize, into a plain Python float; it is a no-op for floats.
    # NOTE(review): assumes fasttext_infer returns a numeric score — confirm.
    return json.dumps(
        {"pred_label": label, "pred_score": float(score)},
        ensure_ascii=False,
    )


def main(argv) -> None:
    """Read Parquet from argv[1], keep positive rows, write Parquet to argv[2].

    An optional argv[3] selects the text column (default ``content``).
    """
    if len(argv) < 3:
        sys.exit(f"usage: {argv[0]} INPUT_PATH SAVE_PATH [CONTENT_KEY]")
    input_path = argv[1]
    save_path = argv[2]
    content_key = argv[3] if len(argv) > 3 else DEFAULT_CONTENT_KEY

    spark = (SparkSession.builder.enableHiveSupport()
             .config("hive.exec.dynamic.partition", "true")
             .config("hive.exec.dynamic.partition.mode", "nonstrict")
             .appName("FastTextInference")
             .getOrCreate())
    try:
        # StringType is also udf()'s default return type; stated explicitly.
        predict_udf = udf(get_fasttext_pred, StringType())

        # For JSON input/output instead of Parquet, use
        # spark.read.json(input_path) and .write.mode("overwrite").json(save_path).
        df = spark.read.parquet(input_path)
        df = df.withColumn("fasttext_pred", predict_udf(col(content_key)))
        df = df.filter(col("fasttext_pred").isNotNull())

        # coalesce() avoids a shuffle, but as a narrow transformation it can
        # also cap the parallelism of the UDF stage above it at 1000 tasks;
        # consider repartition(1000) if inference becomes the bottleneck.
        df.coalesce(1000).write.mode("overwrite").parquet(save_path)
    finally:
        # Release the Spark session even when reading/inference/writing fails.
        spark.stop()


if __name__ == "__main__":
    main(sys.argv)