diff --git a/SparseDiffEngine b/SparseDiffEngine index 0a3070d..339a4cf 160000 --- a/SparseDiffEngine +++ b/SparseDiffEngine @@ -1 +1 @@ -Subproject commit 0a3070d329cafd9d9fe14970ec045a674fbe797c +Subproject commit 339a4cf9c4b7a1e6a57273a388f357391bd0f86e diff --git a/sparsediffpy/_bindings/atoms/kron.h b/sparsediffpy/_bindings/atoms/kron.h new file mode 100644 index 0000000..40eabc8 --- /dev/null +++ b/sparsediffpy/_bindings/atoms/kron.h @@ -0,0 +1,54 @@ +#ifndef ATOM_KRON_H +#define ATOM_KRON_H + +#include "common.h" + +/* Kronecker product Z = kron(A, B). Exactly one operand is variable-free and is + * passed as the parameter capsule (wrap constants with make_parameter and + * param_id=-1 / PARAM_FIXED); the other carries the variables and is passed as + * the child capsule. + * + * Python signature: + * make_kron(param_capsule, child_capsule, const_is_left, p, q, r, s) + * - const_is_left: 1 -> A=param, B=child; 0 -> A=child, B=param. + * - (p, q): A's dims; (r, s): B's dims. */ +static PyObject *py_make_kron(PyObject *self, PyObject *args) +{ + PyObject *param_capsule; + PyObject *child_capsule; + int const_is_left, p, q, r, s; + + if (!PyArg_ParseTuple(args, "OOiiiii", ¶m_capsule, &child_capsule, + &const_is_left, &p, &q, &r, &s)) + { + return NULL; + } + + expr *param_node = + (expr *) PyCapsule_GetPointer(param_capsule, EXPR_CAPSULE_NAME); + if (!param_node) + { + PyErr_SetString(PyExc_ValueError, "invalid parameter capsule"); + return NULL; + } + + expr *child = + (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_kron(param_node, child, const_is_left, p, q, r, s); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create kron node"); + return NULL; + } + + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_KRON_H */ diff --git a/sparsediffpy/_bindings/bindings.c b/sparsediffpy/_bindings/bindings.c index d394fa4..6fd9c07 100644 --- a/sparsediffpy/_bindings/bindings.c +++ b/sparsediffpy/_bindings/bindings.c @@ -16,6 +16,7 @@ #include "atoms/getters.h" #include "atoms/hstack.h" #include "atoms/index.h" +#include "atoms/kron.h" #include "atoms/left_matmul.h" #include "atoms/log.h" #include "atoms/logistic.h" @@ -94,6 +95,8 @@ static PyMethodDef DNLPMethods[] = { "Create elementwise multiply node"}, {"make_convolve", py_make_convolve, METH_VARARGS, "Create 1D full convolution node: y = conv(kernel_param, child)"}, + {"make_kron", py_make_kron, METH_VARARGS, + "Create Kronecker product node: Z = kron(A, B)"}, {"make_matmul", py_make_matmul, METH_VARARGS, "Create matrix multiplication node (Z = X @ Y)"}, {"make_param_scalar_mult", py_make_param_scalar_mult, METH_VARARGS,