diff --git a/src/AggregationFunction/FunctionSum.php b/src/AggregationFunction/FunctionSum.php index 3ef1cf80..e1b89513 100644 --- a/src/AggregationFunction/FunctionSum.php +++ b/src/AggregationFunction/FunctionSum.php @@ -8,23 +8,32 @@ namespace Ublaboo\DataGrid\AggregationFunction; -class FunctionSum implements IAggregationFunction -{ +use DibiFluent; +use Doctrine\ORM\QueryBuilder; +use Nette\Utils\Strings; +final class FunctionSum implements IAggregationFunction +{ /** * @var string */ - protected $column; + private $column; /** * @var int */ - protected $result = 0; + private $result = 0; /** * @var int */ - protected $dataType; + private $dataType; + + /** + * @var callable + */ + private $renderer; + /** * @param string $column @@ -47,23 +56,25 @@ public function getFilterDataType() /** - * @param mixed $data_source + * @param mixed $dataSource * @return void */ - public function processDataSource($data_source) + public function processDataSource($dataSource) { - if ($data_source instanceof \DibiFluent) { - $connection = $data_source->getConnection(); + if ($dataSource instanceof DibiFluent) { + $connection = $dataSource->getConnection(); $this->result = $connection->select('SUM(%n) AS sum', $this->column) - ->from($data_source, 's') + ->from($dataSource, 's') ->fetch() ->sum; } - if ($data_source instanceof \Doctrine\ORM\QueryBuilder) { - $column = \Nette\Utils\Strings::contains($this->column, '.') - ? $this->column - : current($data_source->getRootAliases()).'.'.$this->column; - $this->result = $data_source + + if ($dataSource instanceof QueryBuilder) { + $column = Strings::contains($this->column, '.') + ? $this->column + : current($dataSource->getRootAliases()).'.'.$this->column; + + $this->result = $dataSource ->select(sprintf('SUM(%s)', $column)) ->getQuery() ->getSingleScalarResult(); @@ -72,11 +83,27 @@ public function processDataSource($data_source) /** - * @return int + * @return mixed */ public function renderResult() { - return $this->result; + $result = $this->result; + + if (isset($this->renderer)) { + $result = call_user_func($this->renderer, $result); + } + + return $result; } + + /** + * @param callable|NULL $callback + * @return static + */ + public function setRenderer(callable $callback = NULL) + { + $this->renderer = $callback; + return $this; + } }