diff --git a/dbt/include/spark/macros/utils/any_value.sql b/dbt/include/spark/macros/utils/any_value.sql
new file mode 100644
index 000000000..eb0a019b3
--- /dev/null
+++ b/dbt/include/spark/macros/utils/any_value.sql
@@ -0,0 +1,5 @@
+{% macro spark__any_value(expression) -%}
+    {#-- return any value (non-deterministic) --#}
+    first({{ expression }})
+
+{%- endmacro %}
diff --git a/dbt/include/spark/macros/utils/assert_not_null.sql b/dbt/include/spark/macros/utils/assert_not_null.sql
new file mode 100644
index 000000000..e5454bce9
--- /dev/null
+++ b/dbt/include/spark/macros/utils/assert_not_null.sql
@@ -0,0 +1,9 @@
+{% macro assert_not_null(function, arg) -%}
+    {{ return(adapter.dispatch('assert_not_null', 'dbt')(function, arg)) }}
+{%- endmacro %}
+
+{% macro spark__assert_not_null(function, arg) %}
+
+    coalesce({{function}}({{arg}}), nvl2({{function}}({{arg}}), assert_true({{function}}({{arg}}) is not null), null))
+
+{% endmacro %}
diff --git a/dbt/include/spark/macros/utils/bool_or.sql b/dbt/include/spark/macros/utils/bool_or.sql
new file mode 100644
index 000000000..60d705eb3
--- /dev/null
+++ b/dbt/include/spark/macros/utils/bool_or.sql
@@ -0,0 +1,11 @@
+{#-- Spark v3 supports 'bool_or' and 'any', but Spark v2 needs to use 'max' for this
+  -- https://spark.apache.org/docs/latest/api/sql/index.html#any
+  -- https://spark.apache.org/docs/latest/api/sql/index.html#bool_or
+  -- https://spark.apache.org/docs/latest/api/sql/index.html#max
+#}
+
+{% macro spark__bool_or(expression) -%}
+
+    max({{ expression }})
+
+{%- endmacro %}
diff --git a/dbt/include/spark/macros/utils/concat.sql b/dbt/include/spark/macros/utils/concat.sql
new file mode 100644
index 000000000..30f1a420e
--- /dev/null
+++ b/dbt/include/spark/macros/utils/concat.sql
@@ -0,0 +1,3 @@
+{% macro spark__concat(fields) -%}
+    concat({{ fields|join(', ') }})
+{%- endmacro %}
diff --git a/dbt/include/spark/macros/utils/dateadd.sql b/dbt/include/spark/macros/utils/dateadd.sql
new file mode 100644
index 000000000..e2a20d0f2
--- /dev/null
+++ b/dbt/include/spark/macros/utils/dateadd.sql
@@ -0,0 +1,62 @@
+{% macro spark__dateadd(datepart, interval, from_date_or_timestamp) %}
+
+    {%- set clock_component -%}
+        {# make sure the dates + timestamps are real, otherwise raise an error asap #}
+        to_unix_timestamp({{ assert_not_null('to_timestamp', from_date_or_timestamp) }})
+        - to_unix_timestamp({{ assert_not_null('date', from_date_or_timestamp) }})
+    {%- endset -%}
+
+    {%- if datepart in ['day', 'week'] -%}
+
+        {%- set multiplier = 7 if datepart == 'week' else 1 -%}
+
+        to_timestamp(
+            to_unix_timestamp(
+                date_add(
+                    {{ assert_not_null('date', from_date_or_timestamp) }},
+                    cast({{interval}} * {{multiplier}} as int)
+                )
+            ) + {{clock_component}}
+        )
+
+    {%- elif datepart in ['month', 'quarter', 'year'] -%}
+
+        {%- set multiplier -%}
+            {%- if datepart == 'month' -%} 1
+            {%- elif datepart == 'quarter' -%} 3
+            {%- elif datepart == 'year' -%} 12
+            {%- endif -%}
+        {%- endset -%}
+
+        to_timestamp(
+            to_unix_timestamp(
+                add_months(
+                    {{ assert_not_null('date', from_date_or_timestamp) }},
+                    cast({{interval}} * {{multiplier}} as int)
+                )
+            ) + {{clock_component}}
+        )
+
+    {%- elif datepart in ('hour', 'minute', 'second', 'millisecond', 'microsecond') -%}
+
+        {%- set multiplier -%}
+            {%- if datepart == 'hour' -%} 3600
+            {%- elif datepart == 'minute' -%} 60
+            {%- elif datepart == 'second' -%} 1
+            {%- elif datepart == 'millisecond' -%} (1/1000000)
+            {%- elif datepart == 'microsecond' -%} (1/1000000)
+            {%- endif -%}
+        {%- endset -%}
+
+        to_timestamp(
+            {{ assert_not_null('to_unix_timestamp', from_date_or_timestamp) }}
+            + cast({{interval}} * {{multiplier}} as int)
+        )
+
+    {%- else -%}
+
+        {{ exceptions.raise_compiler_error("macro dateadd not implemented for datepart '" ~ datepart ~ "' on Spark") }}
+
+    {%- endif -%}
+
+{% endmacro %}
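
For orientation, a rough sketch of what a call like {{ dateadd('day', 1, "'2022-01-15'") }} might compile to under spark__dateadd (the date literal is a made-up example, and each assert_not_null(...) call is abbreviated here — the real render wraps its argument in the coalesce/nvl2/assert_true expression defined above):

    to_timestamp(
        to_unix_timestamp(
            date_add(
                date('2022-01-15'),
                cast(1 * 1 as int)
            )
        ) + to_unix_timestamp(to_timestamp('2022-01-15')) - to_unix_timestamp(date('2022-01-15'))
    )

The trailing clock_component term is zero for a plain date and carries the time of day through the arithmetic when a timestamp is passed in.
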
diff --git a/dbt/include/spark/macros/utils/datediff.sql b/dbt/include/spark/macros/utils/datediff.sql
new file mode 100644
index 000000000..d0e684c47
--- /dev/null
+++ b/dbt/include/spark/macros/utils/datediff.sql
@@ -0,0 +1,107 @@
+{% macro spark__datediff(first_date, second_date, datepart) %}
+
+    {%- if datepart in ['day', 'week', 'month', 'quarter', 'year'] -%}
+
+        {# make sure the dates are real, otherwise raise an error asap #}
+        {% set first_date = assert_not_null('date', first_date) %}
+        {% set second_date = assert_not_null('date', second_date) %}
+
+    {%- endif -%}
+
+    {%- if datepart == 'day' -%}
+
+        datediff({{second_date}}, {{first_date}})
+
+    {%- elif datepart == 'week' -%}
+
+        case when {{first_date}} < {{second_date}}
+            then floor(datediff({{second_date}}, {{first_date}})/7)
+            else ceil(datediff({{second_date}}, {{first_date}})/7)
+            end
+
+        -- did we cross a week boundary (Sunday)?
+        + case
+            when {{first_date}} < {{second_date}} and dayofweek({{second_date}}) < dayofweek({{first_date}}) then 1
+            when {{first_date}} > {{second_date}} and dayofweek({{second_date}}) > dayofweek({{first_date}}) then -1
+            else 0 end
+
+    {%- elif datepart == 'month' -%}
+
+        case when {{first_date}} < {{second_date}}
+            then floor(months_between(date({{second_date}}), date({{first_date}})))
+            else ceil(months_between(date({{second_date}}), date({{first_date}})))
+            end
+
+        -- did we cross a month boundary?
+        + case
+            when {{first_date}} < {{second_date}} and dayofmonth({{second_date}}) < dayofmonth({{first_date}}) then 1
+            when {{first_date}} > {{second_date}} and dayofmonth({{second_date}}) > dayofmonth({{first_date}}) then -1
+            else 0 end
+
+    {%- elif datepart == 'quarter' -%}
+
+        case when {{first_date}} < {{second_date}}
+            then floor(months_between(date({{second_date}}), date({{first_date}}))/3)
+            else ceil(months_between(date({{second_date}}), date({{first_date}}))/3)
+            end
+
+        -- did we cross a quarter boundary?
+        + case
+            when {{first_date}} < {{second_date}} and (
+                (dayofyear({{second_date}}) - (quarter({{second_date}}) * 365/4))
+                < (dayofyear({{first_date}}) - (quarter({{first_date}}) * 365/4))
+            ) then 1
+            when {{first_date}} > {{second_date}} and (
+                (dayofyear({{second_date}}) - (quarter({{second_date}}) * 365/4))
+                > (dayofyear({{first_date}}) - (quarter({{first_date}}) * 365/4))
+            ) then -1
+            else 0 end
+
+    {%- elif datepart == 'year' -%}
+
+        year({{second_date}}) - year({{first_date}})
+
+    {%- elif datepart in ('hour', 'minute', 'second', 'millisecond', 'microsecond') -%}
+
+        {%- set divisor -%}
+            {%- if datepart == 'hour' -%} 3600
+            {%- elif datepart == 'minute' -%} 60
+            {%- elif datepart == 'second' -%} 1
+            {%- elif datepart == 'millisecond' -%} (1/1000)
+            {%- elif datepart == 'microsecond' -%} (1/1000000)
+            {%- endif -%}
+        {%- endset -%}
+
+        case when {{first_date}} < {{second_date}}
+            then ceil((
+                {# make sure the timestamps are real, otherwise raise an error asap #}
+                {{ assert_not_null('to_unix_timestamp', assert_not_null('to_timestamp', second_date)) }}
+                - {{ assert_not_null('to_unix_timestamp', assert_not_null('to_timestamp', first_date)) }}
+            ) / {{divisor}})
+            else floor((
+                {{ assert_not_null('to_unix_timestamp', assert_not_null('to_timestamp', second_date)) }}
+                - {{ assert_not_null('to_unix_timestamp', assert_not_null('to_timestamp', first_date)) }}
+            ) / {{divisor}})
+            end
+
+        {% if datepart == 'millisecond' %}
+            + cast(date_format({{second_date}}, 'SSS') as int)
+            - cast(date_format({{first_date}}, 'SSS') as int)
+        {% endif %}
+
+        {% if datepart == 'microsecond' %}
+            {% set capture_str = '[0-9]{4}-[0-9]{2}-[0-9]{2}.[0-9]{2}:[0-9]{2}:[0-9]{2}.([0-9]{6})' %}
+            -- Spark doesn't really support microseconds, so this is a massive hack!
+            -- It will only work if the timestamp-string is of the format
+            -- 'yyyy-MM-dd HH:mm:ss.SSSSSS'
+            + cast(regexp_extract({{second_date}}, '{{capture_str}}', 1) as int)
+            - cast(regexp_extract({{first_date}}, '{{capture_str}}', 1) as int)
+        {% endif %}
+
+    {%- else -%}
+
+        {{ exceptions.raise_compiler_error("macro datediff not implemented for datepart '" ~ datepart ~ "' on Spark") }}
+
+    {%- endif -%}
+
+{% endmacro %}
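
A quick worked check of the week branch above, using made-up dates: for first_date = '2022-01-01' (a Saturday) and second_date = '2022-01-03' (a Monday), datediff gives 2 days, so floor(2/7) = 0; because dayofweek('2022-01-03') = 2 is less than dayofweek('2022-01-01') = 7, the correction case adds 1, giving 1 — the Sunday on 2022-01-02 was crossed, which is the boundary this datepart counts.
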
diff --git a/dbt/include/spark/macros/utils/listagg.sql b/dbt/include/spark/macros/utils/listagg.sql
new file mode 100644
index 000000000..3577edb71
--- /dev/null
+++ b/dbt/include/spark/macros/utils/listagg.sql
@@ -0,0 +1,17 @@
+{% macro spark__listagg(measure, delimiter_text, order_by_clause, limit_num) -%}
+
+  {% if order_by_clause %}
+    {{ exceptions.warn("order_by_clause is not supported for listagg on Spark/Databricks") }}
+  {% endif %}
+
+  {% set collect_list %} collect_list({{ measure }}) {% endset %}
+
+  {% set limited %} slice({{ collect_list }}, 1, {{ limit_num }}) {% endset %}
+
+  {% set collected = limited if limit_num else collect_list %}
+
+  {% set final %} array_join({{ collected }}, {{ delimiter_text }}) {% endset %}
+
+  {% do return(final) %}
+
+{%- endmacro %}
diff --git a/dbt/include/spark/macros/utils/split_part.sql b/dbt/include/spark/macros/utils/split_part.sql
new file mode 100644
index 000000000..d5ae30924
--- /dev/null
+++ b/dbt/include/spark/macros/utils/split_part.sql
@@ -0,0 +1,23 @@
+{% macro spark__split_part(string_text, delimiter_text, part_number) %}
+
+    {% set delimiter_expr %}
+
+        -- escape if starts with a special character
+        case when regexp_extract({{ delimiter_text }}, '([^A-Za-z0-9])(.*)', 1) != '_'
+            then concat('\\', {{ delimiter_text }})
+            else {{ delimiter_text }} end
+
+    {% endset %}
+
+    {% set split_part_expr %}
+
+        split(
+            {{ string_text }},
+            {{ delimiter_expr }}
+        )[({{ part_number - 1 }})]
+
+    {% endset %}
+
+    {{ return(split_part_expr) }}
+
+{% endmacro %}
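
To illustrate the escaping logic, a call such as {{ split_part('col_a', "'|'", 2) }} might compile to roughly the following (the column name and delimiter are hypothetical, and the whitespace and inline comment carried along by the set block are omitted):

    split(
        col_a,
        case when regexp_extract('|', '([^A-Za-z0-9])(.*)', 1) != '_'
            then concat('\\', '|')
            else '|' end
    )[(1)]

Since '|' is a regex metacharacter, the case expression prepends a backslash so that split() treats it literally, and the 1-based part_number is shifted to Spark's 0-based array index at compile time.
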
diff --git a/tests/functional/adapter/test_basic.py b/tests/functional/adapter/test_basic.py
index 70f3267a4..e0cf2f7fe 100644
--- a/tests/functional/adapter/test_basic.py
+++ b/tests/functional/adapter/test_basic.py
@@ -64,7 +64,7 @@ def project_config_update(self):
         }
 
 
-#hese tests were not enabled in the dbtspec files, so skipping here.
+# These tests were not enabled in the dbtspec files, so skipping here.
 # Error encountered was: Error running query: java.lang.ClassNotFoundException: delta.DefaultSource
 @pytest.mark.skip_profile('apache_spark', 'spark_session')
 class TestSnapshotTimestampSpark(BaseSnapshotTimestamp):
@@ -79,5 +79,7 @@ def project_config_update(self):
             }
         }
 
+
+@pytest.mark.skip_profile('spark_session')
 class TestBaseAdapterMethod(BaseAdapterMethod):
     pass
\ No newline at end of file
diff --git a/tests/functional/adapter/utils/fixture_listagg.py b/tests/functional/adapter/utils/fixture_listagg.py
new file mode 100644
index 000000000..0262ca234
--- /dev/null
+++ b/tests/functional/adapter/utils/fixture_listagg.py
@@ -0,0 +1,61 @@
+# SparkSQL does not support 'order by' for its 'listagg' equivalent
+# the argument is ignored, so let's ignore those fields when checking equivalency
+
+models__test_listagg_no_order_by_sql = """
+with data as (
+    select * from {{ ref('data_listagg') }}
+),
+data_output as (
+    select * from {{ ref('data_listagg_output') }}
+),
+calculate as (
+/*
+
+    select
+        group_col,
+        {{ listagg('string_text', "'_|_'", "order by order_col") }} as actual,
+        'bottom_ordered' as version
+    from data
+    group by group_col
+    union all
+    select
+        group_col,
+        {{ listagg('string_text', "'_|_'", "order by order_col", 2) }} as actual,
+        'bottom_ordered_limited' as version
+    from data
+    group by group_col
+    union all
+
+*/
+    select
+        group_col,
+        {{ listagg('string_text', "', '") }} as actual,
+        'comma_whitespace_unordered' as version
+    from data
+    where group_col = 3
+    group by group_col
+    union all
+    select
+        group_col,
+        {{ listagg('DISTINCT string_text', "','") }} as actual,
+        'distinct_comma' as version
+    from data
+    where group_col = 3
+    group by group_col
+    union all
+    select
+        group_col,
+        {{ listagg('string_text') }} as actual,
+        'no_params' as version
+    from data
+    where group_col = 3
+    group by group_col
+)
+select
+    calculate.actual,
+    data_output.expected
+from calculate
+left join data_output
+on calculate.group_col = data_output.group_col
+and calculate.version = data_output.version
+"""
diff --git a/tests/functional/adapter/utils/test_utils.py b/tests/functional/adapter/utils/test_utils.py
new file mode 100644
index 000000000..9137c2f75
--- /dev/null
+++ b/tests/functional/adapter/utils/test_utils.py
@@ -0,0 +1,121 @@
+import pytest
+
+from dbt.tests.adapter.utils.test_any_value import BaseAnyValue
+from dbt.tests.adapter.utils.test_bool_or import BaseBoolOr
+from dbt.tests.adapter.utils.test_cast_bool_to_text import BaseCastBoolToText
+from dbt.tests.adapter.utils.test_concat import BaseConcat
+from dbt.tests.adapter.utils.test_dateadd import BaseDateAdd
+from dbt.tests.adapter.utils.test_datediff import BaseDateDiff
+from dbt.tests.adapter.utils.test_date_trunc import BaseDateTrunc
+from dbt.tests.adapter.utils.test_escape_single_quotes import BaseEscapeSingleQuotesQuote
+from dbt.tests.adapter.utils.test_escape_single_quotes import BaseEscapeSingleQuotesBackslash
+from dbt.tests.adapter.utils.test_except import BaseExcept
+from dbt.tests.adapter.utils.test_hash import BaseHash
+from dbt.tests.adapter.utils.test_intersect import BaseIntersect
+from dbt.tests.adapter.utils.test_last_day import BaseLastDay
+from dbt.tests.adapter.utils.test_length import BaseLength
+from dbt.tests.adapter.utils.test_position import BasePosition
+from dbt.tests.adapter.utils.test_replace import BaseReplace
+from dbt.tests.adapter.utils.test_right import BaseRight
+from dbt.tests.adapter.utils.test_safe_cast import BaseSafeCast
+from dbt.tests.adapter.utils.test_split_part import BaseSplitPart
+from dbt.tests.adapter.utils.test_string_literal import BaseStringLiteral
+
+# requires modification
+from dbt.tests.adapter.utils.test_listagg import BaseListagg
+from dbt.tests.adapter.utils.fixture_listagg import models__test_listagg_yml
+from tests.functional.adapter.utils.fixture_listagg import models__test_listagg_no_order_by_sql
+
+
+class TestAnyValue(BaseAnyValue):
+    pass
+
+
+class TestBoolOr(BaseBoolOr):
+    pass
+
+
+class TestCastBoolToText(BaseCastBoolToText):
+    pass
+
+
+@pytest.mark.skip_profile('spark_session')
+class TestConcat(BaseConcat):
+    pass
+
+
+class TestDateAdd(BaseDateAdd):
+    pass
+
+
+@pytest.mark.skip_profile('spark_session')
+class TestDateDiff(BaseDateDiff):
+    pass
+
+
+class TestDateTrunc(BaseDateTrunc):
+    pass
+
+
+class TestEscapeSingleQuotes(BaseEscapeSingleQuotesQuote):
+    pass
+
+
+class TestExcept(BaseExcept):
+    pass
+
+
+@pytest.mark.skip_profile('spark_session')
+class TestHash(BaseHash):
+    pass
+
+
+class TestIntersect(BaseIntersect):
+    pass
+
+
+class TestLastDay(BaseLastDay):
+    pass
+
+
+class TestLength(BaseLength):
+    pass
+
+
+# SparkSQL does not support 'order by' for its 'listagg' equivalent
+# the argument is ignored, so let's ignore those fields when checking equivalency
+class TestListagg(BaseListagg):
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {
+            "test_listagg.yml": models__test_listagg_yml,
+            "test_listagg.sql": self.interpolate_macro_namespace(
+                models__test_listagg_no_order_by_sql, "listagg"
+            ),
+        }
+
+
+class TestPosition(BasePosition):
+    pass
+
+
+@pytest.mark.skip_profile('spark_session')
+class TestReplace(BaseReplace):
+    pass
+
+
+@pytest.mark.skip_profile('spark_session')
+class TestRight(BaseRight):
+    pass
+
+
+class TestSafeCast(BaseSafeCast):
+    pass
+
+
+class TestSplitPart(BaseSplitPart):
+    pass
+
+
+class TestStringLiteral(BaseStringLiteral):
+    pass
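
For reference, a rough sketch of how the uncommented listagg calls in the fixture might compile once spark__listagg is dispatched (illustrative only): {{ listagg('string_text', "', '") }} renders to approximately

    array_join(collect_list(string_text), ', ')

and the commented-out limited variant {{ listagg('string_text', "'_|_'", "order by order_col", 2) }} would render to approximately

    array_join(slice(collect_list(string_text), 1, 2), '_|_')

with the order by clause dropped after the compile-time warning — which is why the ordered cases stay commented out in the fixture and are excluded from the equivalence check.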