#ifndef PASTELMATLAB_MATLAB_ARGUMENT_HPP
#define PASTELMATLAB_MATLAB_ARGUMENT_HPP
#include "pastel/matlab/matlab_argument.h"
#include "pastel/sys/ensure.h"
#include "pastel/sys/sequence/copy_n.h"

#include <memory>
namespace Pastel
{
//! Creates a 1x1 real numeric Matlab matrix of Type's class.
/*!
output:
Receives the newly created Matlab array.

returns:
A pointer to the single element of 'output', for writing the value.
*/
template <typename Type>
Type* matlabCreateScalar(
	mxArray*& output)
{
	const auto classId = typeToMatlabClassId<Type>();
	output = mxCreateNumericMatrix(1, 1, classId, mxREAL);

	void* storage = mxGetData(output);
	return static_cast<Type*>(storage);
}
//! Creates a real numeric Matlab matrix and an Array viewing its data.
/*!
Preconditions:
allGreaterEqual(extent, 0)

extent:
(width, height) of the array; Matlab sees height rows and
width columns.

output:
Receives the newly created Matlab array.

returns:
A column-major Array aliasing the storage of 'output' (no copy).
*/
template <typename Type>
Array<Type> matlabCreateArray(
	const Vector2i& extent,
	mxArray*& output)
{
	ENSURE(allGreaterEqual(extent, 0));

	// Matlab dimensions are (rows, columns) = (height, width).
	const auto rows = extent.y();
	const auto columns = extent.x();
	output = mxCreateNumericMatrix(rows, columns,
		typeToMatlabClassId<Type>(), mxREAL);

	// Alias the Matlab buffer instead of copying it.
	return Array<Type>(extent,
		withAliasing(static_cast<Type*>(mxGetData(output))),
		StorageOrder::ColumnMajor);
}
//! Creates a real numeric Matlab matrix and an Armadillo view of it.
/*!
Preconditions:
height >= 0
width >= 0

output:
Receives the newly created Matlab array.

returns:
matlabAsMatrix<Type>(output). Because the array is created with
Type's own class id, that call aliases the Matlab storage.
*/
template <typename Type>
arma::Mat<Type> matlabCreateMatrix(
	integer height, integer width,
	mxArray*& output)
{
	ENSURE_OP(width, >=, 0);
	ENSURE_OP(height, >=, 0);

	const auto classId = typeToMatlabClassId<Type>();
	output = mxCreateNumericMatrix(height, width, classId, mxREAL);

	return matlabAsMatrix<Type>(output);
}
//! Convenience overload taking the extent as separate width and height.
/*!
Forwards to matlabCreateArray(Vector2i(width, height), output).
*/
template <typename Type>
Array<Type> matlabCreateArray(
	integer width, integer height,
	mxArray*& output)
{
	const Vector2i extent(width, height);
	return matlabCreateArray<Type>(extent, output);
}
//! Creates a Matlab matrix shaped like 'from' and copies 'from' into it.
/*!
from:
Source matrix; its rows become Matlab rows.

output:
Receives the newly created Matlab array (class of To_Type).

returns:
An Array aliasing the storage of 'output', filled with the
elements of 'from' converted to To_Type.
*/
template <
	typename To_Type,
	typename From_Type>
Array<To_Type> matlabCreateArray(
	const arma::Mat<From_Type>& from,
	mxArray*& output)
{
	// Array extents are (width, height) = (columns, rows).
	const Vector2i extent(from.n_cols, from.n_rows);
	Array<To_Type> result =
		matlabCreateArray<To_Type>(extent, output);

	// Element-wise copy in iteration order of both containers.
	// NOTE(review): assumes Array's iteration order matches
	// Armadillo's column-major order — confirm.
	std::copy(from.begin(), from.end(), result.begin());
	return result;
}
template <typename Type>
Type matlabAsScalar(const mxArray* input,
integer index)
{
ENSURE(mxIsNumeric(input));
integer n = mxGetNumberOfElements(input);
ENSURE_OP(index, >=, 0);
ENSURE_OP(index, <, n);
Type result = 0;
switch(mxGetClassID(input))
{
case mxSINGLE_CLASS:
result = *((real32*)mxGetData(input) + index);
break;
case mxDOUBLE_CLASS:
result = *((real64*)mxGetData(input) + index);
break;
case mxINT8_CLASS:
result = *((int8*)mxGetData(input) + index);
break;
case mxUINT8_CLASS:
result = *((uint8*)mxGetData(input) + index);
break;
case mxINT16_CLASS:
result = *((int16*)mxGetData(input) + index);
break;
case mxUINT16_CLASS:
result = *((uint16*)mxGetData(input) + index);
break;
case mxINT32_CLASS:
result = *((int32*)mxGetData(input) + index);
break;
case mxUINT32_CLASS:
result = *((uint32*)mxGetData(input) + index);
break;
case mxINT64_CLASS:
result = *((int64*)mxGetData(input) + index);
break;
case mxUINT64_CLASS:
result = *((uint64*)mxGetData(input) + index);
break;
default:
// This should not be possible, since
// the above covers all numeric types.
{
bool reachedHere = true;
ENSURE(!reachedHere);
}
break;
};
return result;
}
//! Converts a Matlab char array into a std::string.
/*!
Preconditions:
mxIsChar(input)

The temporary C string obtained from mxArrayToString is always
released with mxFree, even if building the std::string throws.
A NULL result from mxArrayToString (conversion failure) is
reported through ENSURE instead of being passed to std::string,
which would be undefined behavior.
*/
inline std::string matlabAsString(const mxArray* input)
{
	ENSURE(mxIsChar(input));

	// The buffer must be released with mxFree; the deleter is only
	// invoked for a non-null pointer.
	auto releaseText = [](char* memory) { mxFree(memory); };
	std::unique_ptr<char, decltype(releaseText)> text(
		mxArrayToString(input), releaseText);

	// mxArrayToString signals failure by returning NULL.
	ENSURE(text.get() != nullptr);

	return std::string(text.get());
}
namespace MatlabStringAsEnum_
{

	//! Terminates the key search: no key matched the input.
	/*!
	Fails through ENSURE; the returned Type() is only reached if
	ENSURE does not abort.
	*/
	template <
		typename Type,
		typename... ArgumentSet>
	Type matlabStringAsEnum(
		const std::string& input)
	{
		bool unknownEnum = true;
		ENSURE(!unknownEnum);
		return Type();
	}

	//! Returns 'value' if 'input' equals 'key', else tries the remaining pairs.
	/*!
	The input is taken by const reference so that the recursion
	(one level per key) does not copy the string at every step.

	argumentSet:
	The remaining (key, value) pairs, in order.
	*/
	template <
		typename Type,
		typename... ArgumentSet>
	Type matlabStringAsEnum(
		const std::string& input,
		const std::string& key,
		NoDeduction<Type> value,
		ArgumentSet&&... argumentSet)
	{
		if (input == key)
		{
			return value;
		}

		return matlabStringAsEnum<Type>(input,
			std::forward<ArgumentSet>(argumentSet)...);
	}

}
//! Maps a Matlab string argument to an enum value.
/*!
Preconditions:
mxIsChar(input)

argumentSet:
Alternating (key, value) pairs; the value of the first key equal
to the string is returned. No match fails through ENSURE.
*/
template <
	typename Type,
	typename... ArgumentSet>
Type matlabStringAsEnum(
	const mxArray* input,
	ArgumentSet&&... argumentSet)
{
	std::string text = matlabAsString(input);
	return MatlabStringAsEnum_::matlabStringAsEnum<Type>(
		std::move(text),
		std::forward<ArgumentSet>(argumentSet)...);
}
//! Presents a numeric Matlab matrix as a column-major Array.
/*!
Preconditions:
mxIsNumeric(that)

returns:
An Array of extent (columns, rows). If the Matlab class matches
Type, the Array aliases Matlab's data; otherwise the elements are
converted into newly allocated storage.
*/
template <typename Type>
Array<Type> matlabAsArray(
	const mxArray* that)
{
	ENSURE(mxIsNumeric(that));

	// Matlab reports (rows, columns); Array extents are (width, height).
	const integer width = mxGetN(that);
	const integer height = mxGetM(that);
	const Vector2i extent(width, height);

	const bool sameType =
		typeToMatlabClassId<Type>() == mxGetClassID(that);
	if (!sameType)
	{
		// Type mismatch: convert element by element.
		Array<Type> converted(
			extent,
			0,
			StorageOrder::ColumnMajor);
		matlabGetScalars(that, converted.begin());
		return converted;
	}

	// Type match: alias the existing Matlab data, no copy.
	return Array<Type>(
		extent,
		withAliasing(static_cast<Type*>(mxGetData(that))),
		StorageOrder::ColumnMajor);
}
//! Presents a numeric Matlab matrix as an Armadillo matrix.
/*!
Preconditions:
mxIsNumeric(that)

returns:
An m x n matrix. If the Matlab class matches Type, it uses Matlab's
memory directly; otherwise the elements are converted into a
matrix with its own memory.
*/
template <typename Type>
arma::Mat<Type> matlabAsMatrix(
	const mxArray* that)
{
	ENSURE(mxIsNumeric(that));
	const integer m = mxGetM(that);
	const integer n = mxGetN(that);

	if (typeToMatlabClassId<Type>() != mxGetClassID(that))
	{
		// Type mismatch: convert element by element.
		arma::Mat<Type> converted(m, n);
		matlabGetScalars(that, converted.begin());
		return converted;
	}

	// Type match: wrap Matlab's memory instead of copying it.
	const bool copyAuxMem = false;

	// Non-strict, so the matrix may be reallocated to a different
	// size. An empty matrix on the Matlab side is used to request
	// defaults on the C++ side, which then needs e.g. Q.eye(d, d);
	// a strict binding would forbid that.
	const bool strict = false;

	return arma::Mat<Type>(
		static_cast<Type*>(mxGetData(that)),
		m, n,
		copyAuxMem,
		strict);
}
//! Presents all elements of a numeric Matlab array as an n x 1 Array.
/*!
Preconditions:
mxIsNumeric(that)

returns:
An Array of extent (n, 1), n being the number of elements. If the
Matlab class matches Type, the Array aliases Matlab's data;
otherwise the elements are converted into new storage, printing a
warning to std::cout when n >= 2^14.
*/
template <typename Type>
Array<Type> matlabAsLinearizedArray(
	const mxArray* that)
{
	ENSURE(mxIsNumeric(that));
	const integer n = mxGetNumberOfElements(that);
	const Vector2i extent(n, 1);

	if (typeToMatlabClassId<Type>() == mxGetClassID(that))
	{
		// Type match: alias the existing data, no copy.
		return Array<Type>(
			extent,
			withAliasing(static_cast<Type*>(mxGetData(that))));
	}

	// Type mismatch: a converted copy is unavoidable.
	const integer largeCopy = 1 << 14;
	if (n >= largeCopy)
	{
		std::cout << "Warning: Copying a large amount of data "
			<< "because of type mismatch. Using a matching type, "
			<< "if possible, avoids any copying."
			<< std::endl;
	}

	Array<Type> converted(extent);
	matlabGetScalars(that, converted.begin());
	return converted;
}
//! Reports every cell of a Matlab cell array as an Array<Type>.
/*!
Preconditions:
mxIsCell(cellArray)

report:
Called once per cell, in linear cell order, with
matlabAsArray<Type>(cell).

returns:
The number of cells.
*/
template <typename Type, typename Array_Output>
integer matlabGetArrays(
	const mxArray* cellArray,
	Array_Output report)
{
	ENSURE(mxIsCell(cellArray));

	const integer cellCount = mxGetNumberOfElements(cellArray);
	for (integer k = 0; k < cellCount; ++k)
	{
		report(matlabAsArray<Type>(mxGetCell(cellArray, k)));
	}

	return cellCount;
}
//! Copies the elements of a numeric Matlab array, starting at 'offset'.
/*!
Preconditions:
mxIsNumeric(input)
offset >= 0

output:
Receives the (total - offset) elements, converted by copy_n from
the array's storage class. Nothing is written if offset >= total.

returns:
The total number of elements in 'input' (not the number copied).
*/
template <typename Scalar_Iterator>
integer matlabGetScalars(
	const mxArray* input,
	Scalar_Iterator output,
	integer offset)
{
	ENSURE(mxIsNumeric(input));
	ENSURE_OP(offset, >=, 0);

	const integer total = mxGetNumberOfElements(input);
	if (offset >= total)
	{
		// Nothing left to copy.
		return total;
	}

	const integer count = total - offset;
	void* data = mxGetData(input);
	switch (mxGetClassID(input))
	{
	case mxSINGLE_CLASS:
		copy_n(static_cast<real32*>(data) + offset, count, output);
		break;
	case mxDOUBLE_CLASS:
		copy_n(static_cast<real64*>(data) + offset, count, output);
		break;
	case mxINT8_CLASS:
		copy_n(static_cast<int8*>(data) + offset, count, output);
		break;
	case mxUINT8_CLASS:
		copy_n(static_cast<uint8*>(data) + offset, count, output);
		break;
	case mxINT16_CLASS:
		copy_n(static_cast<int16*>(data) + offset, count, output);
		break;
	case mxUINT16_CLASS:
		copy_n(static_cast<uint16*>(data) + offset, count, output);
		break;
	case mxINT32_CLASS:
		copy_n(static_cast<int32*>(data) + offset, count, output);
		break;
	case mxUINT32_CLASS:
		copy_n(static_cast<uint32*>(data) + offset, count, output);
		break;
	case mxINT64_CLASS:
		copy_n(static_cast<int64*>(data) + offset, count, output);
		break;
	case mxUINT64_CLASS:
		copy_n(static_cast<uint64*>(data) + offset, count, output);
		break;
	default:
		// Unreachable: the cases above cover every
		// numeric class.
		ENSURE(false);
		break;
	}

	return total;
}
//! Returns the Matlab numeric class id corresponding to Type.
/*!
Type must be a pointer, integral, or floating-point type
(checked at compile time).

Pointers map to the unsigned integer class of the same width.
Integers map by signedness and size. float and double (ignoring
cv-qualifiers) map to single and double. mxCHAR_CLASS is never
returned: it coincides in type with mxINT8_CLASS, and strings are
handled separately. Anything else fails through ENSURE and yields
mxUNKNOWN_CLASS.
*/
template <typename Type>
mxClassID typeToMatlabClassId()
{
	PASTEL_STATIC_ASSERT(
		std::is_pointer<Type>::value ||
		std::is_integral<Type>::value ||
		std::is_floating_point<Type>::value)

	using Plain = typename std::remove_cv<Type>::type;
	const auto bytes = sizeof(Type);

	if (std::is_pointer<Type>::value)
	{
		if (bytes == 8)
		{
			return mxUINT64_CLASS;
		}
		if (bytes == 4)
		{
			return mxUINT32_CLASS;
		}
	}
	else if (std::is_integral<Type>::value)
	{
		const bool isSigned = std::is_signed<Type>::value;
		switch (bytes)
		{
		case 8:
			return isSigned ? mxINT64_CLASS : mxUINT64_CLASS;
		case 4:
			return isSigned ? mxINT32_CLASS : mxUINT32_CLASS;
		case 2:
			return isSigned ? mxINT16_CLASS : mxUINT16_CLASS;
		case 1:
			return isSigned ? mxINT8_CLASS : mxUINT8_CLASS;
		}
	}
	else if (std::is_floating_point<Type>::value)
	{
		if (std::is_same<Plain, float>::value)
		{
			return mxSINGLE_CLASS;
		}
		if (std::is_same<Plain, double>::value)
		{
			return mxDOUBLE_CLASS;
		}
	}

	bool reachedHere = true;
	ENSURE(!reachedHere);
	return mxUNKNOWN_CLASS;
}
}
#endif