Skip to content

Instantly share code, notes, and snippets.

@chtlp
Forked from pv/example.pyx
Created January 25, 2014 17:07
Show Gist options
  • Save chtlp/8619639 to your computer and use it in GitHub Desktop.
Save chtlp/8619639 to your computer and use it in GitHub Desktop.
import numpy as np
import scipy.linalg.blas
# f2py_pointer is implemented in the C header f2pyptr.h. It unwraps the raw
# void pointer stored in a PyCObject / PyCapsule object. `except NULL` tells
# Cython that a NULL return means a Python exception has already been set
# and must be propagated to the caller.
cdef extern from "f2pyptr.h":
    void *f2py_pointer(object) except NULL
# C-level signature of the Fortran BLAS routine dgemm (reference BLAS:
# C := alpha * op(A) * op(B) + beta * C). Fortran passes every argument by
# reference, hence pointers even for scalars and the transpose flags.
# dgemm is a SUBROUTINE with no return value, so the type returns void:
# calling it through an int-returning pointer type is undefined behaviour.
# NOTE(review): `int` for Fortran INTEGER assumes a 32-bit-integer (LP64,
# not ILP64) BLAS build -- confirm for the target SciPy.
ctypedef void dgemm_t(
    char *transa, char *transb,
    int *m, int *n, int *k,
    double *alpha,
    double *a, int *lda,
    double *b, int *ldb,
    double *beta,
    double *c, int *ldc)

# Raw pointer to the dgemm wrapped by SciPy's f2py object, fetched once when
# the module is imported from its `_cpointer` attribute (available since
# SciPy 0.12.0). If the attribute does not hold a pointer, f2py_pointer
# raises and the module import fails.
cdef dgemm_t *dgemm = <dgemm_t*>f2py_pointer(scipy.linalg.blas.dgemm._cpointer)
def myfunc():
    """Multiply two fixed 2x2 matrices with BLAS dgemm and print the result.

    Prints the product computed through the raw ``dgemm`` pointer, then
    ``np.dot`` of the same inputs as a cross-check. The two printed arrays
    should be identical.

    Raises:
        ValueError: if the inner dimensions of the inputs do not agree.
    """
    cdef double[::1, :] a, b, c
    cdef int m, n, k, lda, ldb, ldc
    cdef double alpha, beta

    # Column-major (Fortran-ordered) storage, matching the ``::1`` on the
    # first axis of the memoryview declarations and what dgemm expects.
    a = np.array([[1, 2], [3, 4]], float, order="F")
    b = np.array([[5, 6], [7, 8]], float, order="F")

    # Derive the BLAS dimensions from the arrays instead of hard-coding them,
    # so changing the inputs cannot make dgemm read or write out of bounds.
    m = a.shape[0]
    k = a.shape[1]
    n = b.shape[1]
    if b.shape[0] != k:
        raise ValueError("inner dimensions of a and b do not match")

    # For a Fortran-contiguous array the leading dimension is its row count.
    lda = m
    ldb = k
    ldc = m

    c = np.empty((m, n), float, order="F")

    # C := 1.0 * A * B + 0.0 * C. Per the BLAS spec, with beta == 0 the
    # (uninitialised) input contents of c are not read.
    alpha = 1.0
    beta = 0.0

    dgemm("N", "N", &m, &n, &k, &alpha, &a[0, 0], &lda, &b[0, 0], &ldb,
          &beta, &c[0, 0], &ldc)

    print(np.asarray(c))
    print(np.dot(a, b))
#ifndef F2PYPTR_H_
#define F2PYPTR_H_

/*
 * f2pyptr.h -- unwrap the raw C pointer held by a PyCObject / PyCapsule,
 * e.g. the `_cpointer` attribute of an f2py-generated wrapper.
 */

#include <Python.h>

/*
 * Return the void pointer stored in `obj`.
 *
 * Accepts a PyCObject (Python 2 builds only) or an exact PyCapsule
 * (Python >= 2.7). On failure returns NULL with a Python exception set:
 * ValueError when `obj` is neither kind of object, otherwise whatever
 * PyCapsule_GetPointer raised.
 *
 * Declared `static` because the definition lives in a header. A non-static
 * definition would be emitted by every translation unit that includes this
 * file, and the link would fail with duplicate symbols.
 */
static void *f2py_pointer(PyObject *obj)
{
#if PY_VERSION_HEX < 0x03000000
    if (PyCObject_Check(obj)) {
        return PyCObject_AsVoidPtr(obj);
    }
#endif
#if PY_VERSION_HEX >= 0x02070000
    if (PyCapsule_CheckExact(obj)) {
        /* Pass the capsule's own name: PyCapsule_GetPointer rejects a name
         * mismatch. An unnamed capsule yields NULL here, which is exactly
         * what was passed before, so named capsules are now also accepted. */
        return PyCapsule_GetPointer(obj, PyCapsule_GetName(obj));
    }
#endif
    PyErr_SetString(PyExc_ValueError, "Not an object containing a void ptr");
    return NULL;
}

#endif /* F2PYPTR_H_ */
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment