Skip to content

Commit

Permalink
Remove constraints on dataof for softmax.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Oct 20, 2022
1 parent 9e65f90 commit f5ef19c
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions lib/nnc/cmd/softmax/mps/ccv_nnc_softmax_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@ static int _ccv_nnc_softmax_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t h
MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_get_command_buffer(stream_context);
const int a_nd = ccv_nnc_tensor_nd(a->info.dim);
const int b_nd = ccv_nnc_tensor_nd(b->info.dim);
if (a_nd <= 2 && b_nd <= 2 && a->dataof == 0 && b->dataof == 0) // Simple case, we use MPS directly.
if (a_nd <= 2 && b_nd <= 2) // Simple case, we use MPS directly.
{
assert(a_nd > 0);
assert(b_nd > 0);
id<MTLBuffer> a_buffer = mpgetbuffer((ccv_nnc_tensor_t*)a);
const int a_rows = a_nd == 1 ? 1 : a->info.dim[0];
const int a_cols = a_nd == 1 ? a->info.dim[0] : a->info.dim[1];
const size_t a_row_bytes = (CCV_IS_TENSOR_VIEW(a) && a_nd == 2) ? a->stride[0] : CCV_GET_DATA_TYPE_SIZE(a->info.datatype) * a_cols;
MPSMatrix* inputMatrix = [[MPSMatrix alloc] initWithBuffer:a_buffer descriptor:[MPSMatrixDescriptor matrixDescriptorWithRows:a_rows columns:a_cols rowBytes:a_row_bytes dataType:ccv_nnc_mps_datatype(a->info.datatype)]];
MPSMatrix* inputMatrix = [[MPSMatrix alloc] initWithBuffer:a_buffer offset:a->dataof descriptor:[MPSMatrixDescriptor matrixDescriptorWithRows:a_rows columns:a_cols rowBytes:a_row_bytes dataType:ccv_nnc_mps_datatype(a->info.datatype)]];
id<MTLBuffer> b_buffer = mpgetbuffer((ccv_nnc_tensor_t*)b);
const int b_rows = b_nd == 1 ? 1 : b->info.dim[0];
const int b_cols = b_nd == 1 ? b->info.dim[0] : b->info.dim[1];
const size_t b_row_bytes = (CCV_IS_TENSOR_VIEW(b) && b_nd == 2) ? b->stride[0] : CCV_GET_DATA_TYPE_SIZE(b->info.datatype) * b_cols;
MPSMatrix* resultMatrix = [[MPSMatrix alloc] initWithBuffer:b_buffer descriptor:[MPSMatrixDescriptor matrixDescriptorWithRows:b_rows columns:b_cols rowBytes:b_row_bytes dataType:ccv_nnc_mps_datatype(b->info.datatype)]];
MPSMatrix* resultMatrix = [[MPSMatrix alloc] initWithBuffer:b_buffer offset:b->dataof descriptor:[MPSMatrixDescriptor matrixDescriptorWithRows:b_rows columns:b_cols rowBytes:b_row_bytes dataType:ccv_nnc_mps_datatype(b->info.datatype)]];
MPSMatrixSoftMax* softmax = [[MPSMatrixSoftMax alloc] initWithDevice:ccv_nnc_default_device()];
softmax.options = MPSKernelOptionsAllowReducedPrecision;
[inputMatrix synchronizeOnCommandBuffer:command_buffer];
Expand Down

0 comments on commit f5ef19c

Please sign in to comment.