Fix memory synchronization between MPS and MPSGraph
Ideally, I should be able to use MTLFence to do the same. However, it is
not exposed from either MPS or MPSGraph. This hack allows us to encode
the MPSGraph access and the MPS work on the same command buffer, thereby
making sure the underlying MPSGraph synchronizes the memory access
properly.

This hack now enables us to use SD v2.1 768-v with the upcast fix while
keeping memory usage down.
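
For illustration, here is a minimal sketch of the trick using plain MPSGraph APIs: encode a trivial max(x, 0) read of the MPS-written buffer on the same MPSCommandBuffer, so MPSGraph inserts its own synchronization for the heap-backed data. The helper name and parameters are hypothetical stand-ins for the ccv-specific plumbing shown in the diff below.

// A minimal sketch, not the ccv helpers used in the diff: `commandBuffer`, `buffer`,
// `shape` and `dataType` are assumed to come from the surrounding code, with the MPS
// kernel already encoded on the same `commandBuffer`.
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

static void encode_dummy_graph_read(MPSCommandBuffer* commandBuffer, id<MTLBuffer> buffer, MPSShape* shape, MPSDataType dataType)
{
	MPSGraph* graph = [[MPSGraph alloc] init];
	// The graph just computes max(x, 0) over the result the MPS kernel wrote; the value is
	// irrelevant, the point is that MPSGraph now tracks the buffer and synchronizes access to it.
	MPSGraphTensor* input = [graph placeholderWithShape:shape dataType:dataType name:nil];
	MPSGraphTensor* zero = [graph constantWithScalar:0 dataType:dataType];
	MPSGraphTensor* output = [graph maximumWithPrimaryTensor:input secondaryTensor:zero name:nil];
	MPSGraphTensorData* data = [[MPSGraphTensorData alloc] initWithMTLBuffer:buffer shape:shape dataType:dataType];
	// Encode on the same command buffer the MPS kernel was encoded on, writing the result
	// back into the same buffer, so everything stays within one MTLCommandBuffer.
	[graph encodeToCommandBuffer:commandBuffer
	                       feeds:@{ input: data }
	            targetOperations:nil
	           resultsDictionary:@{ output: data }
	         executionDescriptor:nil];
	[data release];
	[graph release];
}
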
liuliu committed Dec 11, 2022
1 parent 58ac8b1 commit 82e9c1a
lib/nnc/cmd/softmax/mps/ccv_nnc_softmax_mps.m (20 additions, 0 deletions)
@@ -37,6 +37,26 @@ static int _ccv_nnc_softmax_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t h
 			if (resultMatrix != inputMatrix)
 				[resultMatrix release];
 			[softmax release];
+			// Encode a dummy MPSGraph that accesses the softmax result in the same command buffer. This way,
+			// MPSGraph will internally synchronize properly with the underlying MTLHeap-allocated
+			// data. Some simple performance analysis shows this won't impact performance on M2 chips
+			// (haven't tested on M1 / A* chips). I think this also fixes the MPSGraph / MPSGEMM interaction
+			// that blocks mixing them together on 2.0 models. Will do more performance analysis to get to
+			// the bottom of it.
+			ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, hint, flags, inputs, input_size, outputs, output_size);
+			int indices[1];
+			MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
+				MPSGraphTensor* mps_input_a;
+				MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, b, b->info.dim, b->stride, &mps_input_a);
+				[inputTensors addObject:mps_input_a];
+				MPSGraphShapedType* mps_a_shape = ccv_nnc_mps_graph_tensor_input_shape(a, a->info.dim, a->stride);
+				[inputShapedTypes addObject:mps_a_shape];
+				MPSGraphTensor* mps_min = [graph constantWithScalar:0 dataType:ccv_nnc_mps_datatype(b->info.datatype)];
+				MPSGraphTensor* mps_b = [graph maximumWithPrimaryTensor:mps_a secondaryTensor:mps_min name:nil];
+				[resultTensors addObject:mps_b];
+			});
+			MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(b, b->info.dim, b->stride);
+			ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data_a], &b, (int*[]){ b->info.dim }, (int*[]){ b->stride }, 1);
 		} else {
 			// Otherwise, use MPSGraph.
 			ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, hint, flags, inputs, input_size, outputs, output_size);
