[AIRFLOW-243] Create NamedHivePartitionSensor
Closes apache#1593 from zodiac/create-NamedHivePartitionSensor
ldct authored and aoen committed Jun 29, 2016
1 parent 4a84a57 commit bf28de4
Showing 4 changed files with 151 additions and 4 deletions.
42 changes: 41 additions & 1 deletion airflow/hooks/hive_hooks.py
@@ -22,6 +22,7 @@
import re
import subprocess
from tempfile import NamedTemporaryFile
import hive_metastore
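# Thrift-generated Hive metastore bindings; needed below for
# hive_metastore.ttypes.NoSuchObjectException in check_for_named_partition().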

from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
@@ -321,7 +322,17 @@ def get_conn(self):
        return self.metastore

    def check_for_partition(self, schema, table, partition):
        """
        Checks whether a partition exists

        :param schema: Name of hive schema (database) @table belongs to
        :type schema: string
        :param table: Name of hive table @partition belongs to
        :type table: string
        :param partition: Expression that matches the partitions to check for
            (e.g. `a = 'b' AND c = 'd'`)
        :type partition: string
        :rtype: boolean

        >>> hh = HiveMetastoreHook()
        >>> t = 'static_babynames_partitioned'
@@ -337,6 +348,35 @@ def check_for_partition(self, schema, table, partition):
        else:
            return False

    def check_for_named_partition(self, schema, table, partition_name):
        """
        Checks whether a partition with a given name exists

        :param schema: Name of hive schema (database) @table belongs to
        :type schema: string
        :param table: Name of hive table @partition belongs to
        :type table: string
        :param partition_name: Name of the partition to check for
            (e.g. `a=b/c=d`)
        :type partition_name: string
        :rtype: boolean

        >>> hh = HiveMetastoreHook()
        >>> t = 'static_babynames_partitioned'
        >>> hh.check_for_named_partition('airflow', t, "ds=2015-01-01")
        True
        >>> hh.check_for_named_partition('airflow', t, "ds=xxx")
        False
        """
        self.metastore._oprot.trans.open()
        try:
            self.metastore.get_partition_by_name(
                schema, table, partition_name)
            return True
        except hive_metastore.ttypes.NoSuchObjectException:
            return False
        finally:
            self.metastore._oprot.trans.close()

    def get_table(self, table_name, db='default'):
        """Get a metastore table object
1 change: 1 addition & 0 deletions airflow/operators/__init__.py
@@ -68,6 +68,7 @@
    'HivePartitionSensor',
    'HttpSensor',
    'MetastorePartitionSensor',
    'NamedHivePartitionSensor',
    'S3KeySensor',
    'S3PrefixSensor',
    'SqlSensor',
81 changes: 78 additions & 3 deletions airflow/operators/sensors.py
@@ -222,15 +222,91 @@ def poke(self, context):
        return count


class NamedHivePartitionSensor(BaseSensorOperator):
    """
    Waits for a set of partitions to show up in Hive.

    :param partition_names: List of fully qualified names of the
        partitions to wait for. A fully qualified name is of the
        form schema.table/pk1=pv1/pk2=pv2, for example,
        default.users/ds=2016-01-01. Each name is parsed and checked
        against the metastore Thrift client "get_partition_by_name"
        method. Note that you cannot use logical operators as in
        HivePartitionSensor.
    :type partition_names: list of strings
    :param metastore_conn_id: reference to the metastore thrift service
        connection id
    :type metastore_conn_id: str
    """

    template_fields = ('partition_names', )

    @apply_defaults
    def __init__(
            self,
            partition_names,
            metastore_conn_id='metastore_default',
            poke_interval=60 * 3,
            *args,
            **kwargs):
        super(NamedHivePartitionSensor, self).__init__(
            poke_interval=poke_interval, *args, **kwargs)

        if isinstance(partition_names, basestring):
            raise TypeError('partition_names must be an array of strings')

        for partition_name in partition_names:
            self.parse_partition_name(partition_name)

        self.metastore_conn_id = metastore_conn_id
        self.partition_names = partition_names
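        # Index of the next partition to check; partitions already found on
        # earlier pokes are not re-checked.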
        self.next_poke_idx = 0

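    # Illustrative behaviour (not part of the diff):
    # parse_partition_name('default.users/ds=2016-01-01') returns
    # ('default', 'users', 'ds=2016-01-01'); a name that does not match
    # schema.table/partition raises ValueError.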
    def parse_partition_name(self, partition):
        try:
            schema, table_partition = partition.split('.')
            table, partition = table_partition.split('/', 1)
            return schema, table, partition
        except ValueError:
            raise ValueError('Could not parse ' + partition)

    def poke(self, context):

        if not hasattr(self, 'hook'):
            self.hook = airflow.hooks.hive_hooks.HiveMetastoreHook(
                metastore_conn_id=self.metastore_conn_id)

        def poke_partition(partition):

            schema, table, partition = self.parse_partition_name(partition)

            logging.info(
                'Poking for {schema}.{table}/{partition}'.format(**locals())
            )
            return self.hook.check_for_named_partition(
                schema, table, partition)

        while self.next_poke_idx < len(self.partition_names):
            if poke_partition(self.partition_names[self.next_poke_idx]):
                self.next_poke_idx += 1
            else:
                return False

        return True
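# Illustrative usage of the new sensor (not part of this diff; the dag
# object, task_id and partition values here are hypothetical):
#
#   wait_for_users = NamedHivePartitionSensor(
#       task_id='wait_for_users_partition',
#       partition_names=['default.users/ds={{ ds }}'],
#       metastore_conn_id='metastore_default',
#       dag=dag)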


class HivePartitionSensor(BaseSensorOperator):
    """
    Waits for a partition to show up in Hive.

    Note: Because @partition supports general logical operators, it
    can be inefficient. Consider using NamedHivePartitionSensor instead if
    you don't need the full flexibility of HivePartitionSensor.

    :param table: The name of the table to wait for, supports the dot
        notation (my_database.my_table)
    :type table: string
    :param partition: The partition clause to wait for. This is passed as
        is to the metastore Thrift client "get_partitions_by_filter" method,
        and supports SQL-like notation as in `ds='2015-01-01'
        AND type='value'`, as well as comparison signs as in "ds>=2015-01-01"
    :type partition: string
@@ -264,7 +340,6 @@ def poke(self, context):
            'Poking for table {self.schema}.{self.table}, '
            'partition {self.partition}'.format(**locals()))
        if not hasattr(self, 'hook'):
            self.hook = airflow.hooks.hive_hooks.HiveMetastoreHook(
                metastore_conn_id=self.metastore_conn_id)
        return self.hook.check_for_partition(
31 changes: 31 additions & 0 deletions tests/operators/hive_operator.py
@@ -24,6 +24,7 @@

import os
import unittest
import nose


DEFAULT_DATE = datetime.datetime(2015, 1, 1)
@@ -163,6 +164,36 @@ def test_hive_stats(self):
            dag=self.dag)
        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)

    def test_named_hive_partition_sensor(self):
        t = operators.sensors.NamedHivePartitionSensor(
            task_id='hive_partition_check',
            partition_names=["airflow.static_babynames_partitioned/ds={{ds}}"],
            dag=self.dag)
        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)

    def test_named_hive_partition_sensor_succeeds_on_multiple_partitions(self):
        t = operators.sensors.NamedHivePartitionSensor(
            task_id='hive_partition_check',
            partition_names=[
                "airflow.static_babynames_partitioned/ds={{ds}}",
                "airflow.static_babynames_partitioned/ds={{ds}}"
            ],
            dag=self.dag)
        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)

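    # The test below relies on a short poke_interval (0.1s) and timeout (1s):
    # the nonexistent partition is never found, so the sensor raises
    # AirflowSensorTimeout, asserted via the nose.tools.raises decorator.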
    @nose.tools.raises(airflow.exceptions.AirflowSensorTimeout)
    def test_named_hive_partition_sensor_times_out_on_nonexistent_partition(self):
        t = operators.sensors.NamedHivePartitionSensor(
            task_id='hive_partition_check',
            partition_names=[
                "airflow.static_babynames_partitioned/ds={{ds}}",
                "airflow.static_babynames_partitioned/ds=nonexistent"
            ],
            poke_interval=0.1,
            timeout=1,
            dag=self.dag)
        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)

    def test_hive_partition_sensor(self):
        t = operators.sensors.HivePartitionSensor(
            task_id='hive_partition_check',
