Skip to content

Commit

Permalink
Fix model.
Browse files Browse the repository at this point in the history
  • Loading branch information
chantera committed Aug 21, 2018
1 parent e3776f1 commit 1e8a843
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 15 deletions.
42 changes: 31 additions & 11 deletions contrib/c/dynet_c/model.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include <dynet_c/config.h>

#include <string>

#include <dynet/devices.h>
#include <dynet/model.h>
#include <dynet_c/internal.h>
#include <dynet_c/model.h>
Expand Down Expand Up @@ -42,6 +45,14 @@ DYNET_C_STATUS dynetGetParameterValues(
return DYNET_C_OK;
} DYNET_C_HANDLE_EXCEPTIONS

DYNET_C_STATUS dynetGetParameterGradients(
dynetParameter_t *param, dynetTensor_t **tensor) try {
DYNET_C_CHECK_NOT_NULL(param);
DYNET_C_CHECK_NOT_NULL(tensor);
*tensor = to_c_ptr(to_cpp_ptr(param)->gradients());
return DYNET_C_OK;
} DYNET_C_HANDLE_EXCEPTIONS

DYNET_C_STATUS dynetSetParameterUpdated(
dynetParameter_t *param, DYNET_C_BOOL b) try {
DYNET_C_CHECK_NOT_NULL(param);
Expand Down Expand Up @@ -116,7 +127,8 @@ DYNET_C_STATUS dynetCreateParameterCollection(
return DYNET_C_OK;
} DYNET_C_HANDLE_EXCEPTIONS

DYNET_C_STATUS dynetDeleteParameterCollection(dynetParameterCollection_t *pc) try {
DYNET_C_STATUS dynetDeleteParameterCollection(
dynetParameterCollection_t *pc) try {
DYNET_C_CHECK_NOT_NULL(pc);
delete to_cpp_ptr(pc);
return DYNET_C_OK;
Expand Down Expand Up @@ -147,32 +159,40 @@ DYNET_C_STATUS dynetGetParameterCollectionWeightDecayLambda(

DYNET_C_STATUS dynetAddParametersToParameterCollection(
dynetParameterCollection_t *pc, const dynetDim_t *d,
const dynetParameterInit_t *init, dynetParameter_t **newobj) try {
const dynetParameterInit_t *init, const char *name, dynetDevice_t *device,
dynetParameter_t **newobj) try {
DYNET_C_CHECK_NOT_NULL(pc);
DYNET_C_CHECK_NOT_NULL(d);
DYNET_C_CHECK_NOT_NULL(newobj);
const std::string name_str = name ? name : "";
dynet::Device *device_ptr = device ?
to_cpp_ptr(device) : dynet::default_device;
if (init) {
*newobj = to_c_ptr_from_value(
to_cpp_ptr(pc)->add_parameters(*to_cpp_ptr(d), *to_cpp_ptr(init)));
*newobj = to_c_ptr_from_value(to_cpp_ptr(pc)->add_parameters(
*to_cpp_ptr(d), *to_cpp_ptr(init), name_str, device_ptr));
} else {
*newobj = to_c_ptr_from_value(
to_cpp_ptr(pc)->add_parameters(*to_cpp_ptr(d)));
*newobj = to_c_ptr_from_value(to_cpp_ptr(pc)->add_parameters(
*to_cpp_ptr(d), name_str, device_ptr));
}
return DYNET_C_OK;
} DYNET_C_HANDLE_EXCEPTIONS

DYNET_C_STATUS dynetAddLookupParametersToParameterCollection(
dynetParameterCollection_t *pc, uint32_t n, const dynetDim_t *d,
const dynetParameterInit_t *init, dynetLookupParameter_t **newobj) try {
const dynetParameterInit_t *init, const char *name, dynetDevice_t *device,
dynetLookupParameter_t **newobj) try {
DYNET_C_CHECK_NOT_NULL(pc);
DYNET_C_CHECK_NOT_NULL(d);
DYNET_C_CHECK_NOT_NULL(newobj);
const std::string name_str = name ? name : "";
dynet::Device *device_ptr = device ?
to_cpp_ptr(device) : dynet::default_device;
if (init) {
*newobj = to_c_ptr_from_value(
to_cpp_ptr(pc)->add_lookup_parameters(n, *to_cpp_ptr(d), *to_cpp_ptr(init)));
*newobj = to_c_ptr_from_value(to_cpp_ptr(pc)->add_lookup_parameters(
n, *to_cpp_ptr(d), *to_cpp_ptr(init), name_str, device_ptr));
} else {
*newobj = to_c_ptr_from_value(
to_cpp_ptr(pc)->add_lookup_parameters(n, *to_cpp_ptr(d)));
*newobj = to_c_ptr_from_value(to_cpp_ptr(pc)->add_lookup_parameters(
n, *to_cpp_ptr(d), name_str, device_ptr));
}
return DYNET_C_OK;
} DYNET_C_HANDLE_EXCEPTIONS
Expand Down
24 changes: 20 additions & 4 deletions contrib/c/dynet_c/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define DYNET_C_MODEL_H_

#include <dynet_c/define.h>
#include <dynet_c/devices.h>
#include <dynet_c/dim.h>
#include <dynet_c/model.h>
#include <dynet_c/param-init.h>
Expand Down Expand Up @@ -61,6 +62,15 @@ DYNET_C_API DYNET_C_STATUS dynetGetParameterDim(
DYNET_C_API DYNET_C_STATUS dynetGetParameterValues(
dynetParameter_t *param, dynetTensor_t **tensor);

/**
* Retrieves internal gradients in the parameter as a tensor.
* @param param Pointer of a handler.
* @param tensor Pointer to receive a tensor of the internal gradients.
* @return Status code.
*/
DYNET_C_API DYNET_C_STATUS dynetGetParameterGradients(
dynetParameter_t *param, dynetTensor_t **tensor);

/**
* Sets update status of the parameter.
* @param param Pointer of a handler.
Expand Down Expand Up @@ -200,25 +210,31 @@ DYNET_C_API DYNET_C_STATUS dynetGetParameterCollectionWeightDecayLambda(
* @param pc Pointer of a handler.
* @param d Pointer of a dim.
* @param init Pointer of an initializer.
* @param name Name of the parameter.
* @param device Pointer of a device.
* @param newobj Pointer to receive a Parameter object.
* @return Status code.
*/
DYNET_C_API DYNET_C_STATUS dynetAddParametersToParameterCollection(
dynetParameterCollection_t *pc, const dynetDim_t *d,
const dynetParameterInit_t *init, dynetParameter_t **newobj);
const dynetParameterInit_t *init, const char *name, dynetDevice_t *device,
dynetParameter_t **newobj);

/**
* Adds a lookup parameter to the ParameterCollection.
* @param pc Pointer of a handler.
* @param n Dimension of each embedding.
* @param d Pointer of a dim.
* @param init Pointer of an initializer.
* @param name Name of the parameter.
* @param device Pointer of a device.
* @param newobj Pointer to receive a LookupParameter object.
* @return Status code.
*/
DYNET_C_API DYNET_C_STATUS dynetAddLookupParametersToParameterCollection(
dynetParameterCollection_t *pc, uint32_t n, const dynetDim_t *d,
const dynetParameterInit_t *init, dynetLookupParameter_t **newobj);
const dynetParameterInit_t *init, const char *name, dynetDevice_t *device,
dynetLookupParameter_t **newobj);

/**
* Adds a subcollection to the ParameterCollection.
Expand All @@ -228,8 +244,8 @@ DYNET_C_API DYNET_C_STATUS dynetAddLookupParametersToParameterCollection(
* @return Status code.
*/
DYNET_C_API DYNET_C_STATUS dynetAddSubcollectionToParameterCollection(
dynetParameterCollection *pc, const char *name,
dynetParameterCollection **newobj);
dynetParameterCollection_t *pc, const char *name,
dynetParameterCollection_t **newobj);

/**
* Gets the total number of tunable parameters in the ParameterCollection.
Expand Down

0 comments on commit 1e8a843

Please sign in to comment.