找回密码
 立即注册
问题
首先,我使用的是 TFX 0.21.2 版和 Tensorflow 2.1 版。

我建立了一个管道,主要以芝加哥出租车为例。执行 Trainer 组件时,我可以在日志中看到以下内容:

信息培训完成。模型写入 /root/airflow/tfx/pipelines/fish/Trainer/Model/9/serving_model_dir

检查上面的目录时,它是空的。我错过了什么?

这是我的 DAG 定义文件(忽略 import 语句):
  1. _pipeline_name = 'fish'
  2. _airflow_config = AirflowPipelineConfig(airflow_dag_config = {
  3.     'schedule_interval': None,
  4.     'start_date': datetime.datetime(2019, 1, 1),
  5. })
  6. _project_root = os.path.join(os.environ['HOME'], 'airflow')
  7. _data_root = os.path.join(_project_root, 'data', 'fish_data')
  8. _module_file = os.path.join(_project_root, 'dags', 'fishUtils.py')
  9. _serving_model_dir = os.path.join(_project_root, 'serving_model', _pipeline_name)
  10. _tfx_root = os.path.join(_project_root, 'tfx')
  11. _pipeline_root = os.path.join(_tfx_root, 'pipelines', _pipeline_name)
  12. _metadata_path = os.path.join(_tfx_root, 'metadata', _pipeline_name,
  13.                               'metadata.db')


  14. def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
  15.                      module_file: Text, serving_model_dir: Text,
  16.                      metadata_path: Text,
  17.                      direct_num_workers: int) -> pipeline.Pipeline:

  18.     examples = external_input(data_root)
  19.     example_gen = CsvExampleGen(input=examples)

  20.     statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

  21.     infer_schema = SchemaGen(
  22.       statistics=statistics_gen.outputs['statistics'],
  23.       infer_feature_shape=False)

  24.     validate_stats = ExampleValidator(
  25.       statistics=statistics_gen.outputs['statistics'],
  26.       schema=infer_schema.outputs['schema'])

  27.     trainer = Trainer(
  28.     examples=example_gen.outputs['examples'], schema=infer_schema.outputs['schema'],
  29.     module_file=_module_file, train_args= trainer_pb2.TrainArgs(num_steps=10000),
  30.     eval_args= trainer_pb2.EvalArgs(num_steps=5000))

  31.     model_validator = ModelValidator(
  32.       examples=example_gen.outputs['examples'],
  33.       model=trainer.outputs['model'])

  34.     pusher = Pusher(
  35.       model=trainer.outputs['model'],
  36.       model_blessing=model_validator.outputs['blessing'],
  37.       push_destination=pusher_pb2.PushDestination(
  38.         filesystem=pusher_pb2.PushDestination.Filesystem(
  39.           base_directory=_serving_model_dir)))

  40.     return pipeline.Pipeline(
  41.       pipeline_name=_pipeline_name,
  42.       pipeline_root=_pipeline_root,
  43.       components=[
  44.           example_gen,
  45.           statistics_gen,
  46.           infer_schema,
  47.           validate_stats,
  48.           trainer,
  49.           model_validator,
  50.           pusher],
  51.       enable_cache=True,
  52.       metadata_connection_config=metadata.sqlite_metadata_connection_config(
  53.           metadata_path),
  54.       beam_pipeline_args=['--direct_num_workers=%d' % direct_num_workers]
  55.   )

  56. runner = AirflowDagRunner(config = _airflow_config)
  57. DAG = runner.run(
  58.     _create_pipeline(
  59.         pipeline_name=_pipeline_name,
  60.         pipeline_root=_pipeline_root,
  61.         data_root=_data_root,
  62.         module_file=_module_file,
  63.         serving_model_dir=_serving_model_dir,
  64.         metadata_path=_metadata_path,
  65.         # 0 means auto-detect based on on the number of CPUs available during
  66.         # execution time.
  67.         direct_num_workers=0))
复制代码

这是我的模块文件:
  1. _DENSE_FLOAT_FEATURE_KEYS = ['length']

  2. real_valued_columns = [tf.feature_column.numeric_column('length')]

  3. def _eval_input_receiver_fn():

  4.   serialized_tf_example = tf.compat.v1.placeholder(
  5.       dtype=tf.string, shape=[None], name='input_example_tensor')

  6.   features = tf.io.parse_example(
  7.       serialized=serialized_tf_example,
  8.       features={
  9.           'length': tf.io.FixedLenFeature([], tf.float32),
  10.           'label': tf.io.FixedLenFeature([], tf.int64),
  11.       })

  12.   receiver_tensors = {'examples': serialized_tf_example}

  13.   return tfma.export.EvalInputReceiver(
  14.       features={'length' : features['length']},
  15.       receiver_tensors=receiver_tensors,
  16.       labels= features['label'],
  17.       )

  18. def parser(serialized_example):

  19.   features = tf.io.parse_single_example(
  20.       serialized_example,
  21.       features={
  22.           'length': tf.io.FixedLenFeature([], tf.float32),
  23.           'label': tf.io.FixedLenFeature([], tf.int64),
  24.       })
  25.   return ({'length' : features['length']}, features['label'])

  26. def _input_fn(filenames):
  27.   # TFRecordDataset doesn't directly accept paths with wildcards
  28.   filenames = tf.data.Dataset.list_files(filenames)
  29.   dataset = tf.data.TFRecordDataset(filenames, 'GZIP')
  30.   dataset = dataset.map(parser)
  31.   dataset = dataset.shuffle(2000)
  32.   dataset = dataset.batch(40)
  33.   dataset = dataset.repeat(10)

  34.   return dataset

  35. def trainer_fn(trainer_fn_args, schema):

  36.     estimator = tf.estimator.LinearClassifier(feature_columns=real_valued_columns)

  37.     train_input_fn = lambda: _input_fn(trainer_fn_args.train_files)

  38.     train_spec = tf.estimator.TrainSpec(
  39.       train_input_fn,
  40.       max_steps=trainer_fn_args.train_steps)

  41.     eval_input_fn = lambda: _input_fn(trainer_fn_args.eval_files)

  42.     eval_spec = tf.estimator.EvalSpec(
  43.       eval_input_fn,
  44.       steps=trainer_fn_args.eval_steps,
  45.       name='fish-eval')

  46.     receiver_fn = lambda: _eval_input_receiver_fn()

  47.     return {
  48.       'estimator': estimator,
  49.       'train_spec': train_spec,
  50.       'eval_spec': eval_spec,
  51.       'eval_input_receiver_fn': receiver_fn
  52.   }
复制代码

在此先感谢您的帮助!

回答
为遇到与我相同问题的任何人发布解决方案。

模型没有写入文件系统的原因是估计器需要一个配置参数来知道在哪里写入模型。

以下对 trainer_fn 函数的修改应该可以解决问题。
  1. run_config = tf.estimator.RunConfig(save_checkpoints_steps=999, keep_checkpoint_max=1)  

  2. run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir)

  3. estimator=tf.estimator.LinearClassifier(feature_columns=real_valued_columns,config=run_config)
复制代码






上一篇:使用底部选项卡导航器时不显示标题
下一篇:更改任务触发器但不反映 OIM 流程表单上的字段值