Skip to content

Commit

Permalink
feat: gemma2 samples with accelerated TPU and GPU (#4395)
Browse files Browse the repository at this point in the history
  • Loading branch information
Stepan Rasputny authored Oct 7, 2024
1 parent fc150e3 commit 8957a58
Show file tree
Hide file tree
Showing 6 changed files with 521 additions and 0 deletions.
78 changes: 78 additions & 0 deletions vertexai/gemma2/gemma2_predict_gpu.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package snippets

// [START generativeaionvertexai_gemma2_predict_gpu]
import (
"context"
"fmt"
"io"

"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"

"google.golang.org/protobuf/types/known/structpb"
)

// predictGPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with GPU accelerators.
func predictGPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
ctx := context.Background()

// Note: client can be initialized in the following way:
// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
// if err != nil {
// return fmt.Errorf("unable to create prediction client: %v", err)
// }
// defer client.Close()

gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
prompt := "Why is the sky blue?"
parameters := map[string]interface{}{
"temperature": 0.9,
"maxOutputTokens": 1024,
"topP": 1.0,
"topK": 1,
}

// Encapsulate the prompt in a correct format for TPUs.
// Pay attention that prompt should be set in "inputs" field.
// Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.9}}]
promptValue, err := structpb.NewValue(map[string]interface{}{
"inputs": prompt,
"parameters": parameters,
})
if err != nil {
fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
return err
}

req := &aiplatformpb.PredictRequest{
Endpoint: gemma2Endpoint,
Instances: []*structpb.Value{promptValue},
}

resp, err := client.Predict(ctx, req)
if err != nil {
return err
}

prediction := resp.GetPredictions()
value := prediction[0].GetStringValue()
fmt.Fprintf(w, "%v", value)

return nil
}

// [END generativeaionvertexai_gemma2_predict_gpu]
77 changes: 77 additions & 0 deletions vertexai/gemma2/gemma2_predict_tpu.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package snippets

// [START generativeaionvertexai_gemma2_predict_tpu]
import (
"context"
"fmt"
"io"

"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"

"google.golang.org/protobuf/types/known/structpb"
)

// predictTPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.
func predictTPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
ctx := context.Background()

// Note: client can be initialized in the following way:
// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
// if err != nil {
// return fmt.Errorf("unable to create prediction client: %v", err)
// }
// defer client.Close()

gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
prompt := "Why is the sky blue?"
parameters := map[string]interface{}{
"temperature": 0.9,
"maxOutputTokens": 1024,
"topP": 1.0,
"topK": 1,
}

// Encapsulate the prompt in a correct format for TPUs.
// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
promptValue, err := structpb.NewValue(map[string]interface{}{
"prompt": prompt,
"parameters": parameters,
})
if err != nil {
fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
return err
}

req := &aiplatformpb.PredictRequest{
Endpoint: gemma2Endpoint,
Instances: []*structpb.Value{promptValue},
}

resp, err := client.Predict(ctx, req)
if err != nil {
return err
}

prediction := resp.GetPredictions()
value := prediction[0].GetStringValue()
fmt.Fprintf(w, "%v", value)

return nil
}

// [END generativeaionvertexai_gemma2_predict_tpu]
63 changes: 63 additions & 0 deletions vertexai/gemma2/gemma2_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package snippets

import (
"bytes"
"context"
"strings"
"testing"

"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
"github.com/GoogleCloudPlatform/golang-samples/internal/testutil"
"github.com/googleapis/gax-go/v2"
)

func TestPredictGemma2(t *testing.T) {
tc := testutil.SystemTest(t)

projectID := tc.ProjectID
var buf bytes.Buffer
client := PredictionsClient{}

t.Run("GPU predict", func(t *testing.T) {
buf.Reset()
// Mock ID used to check if GPU was called
if err := predictGPU(&buf, client, projectID, GPUEndpointRegion, GPUEndpointID); err != nil {
t.Fatal(err)
}

if got := buf.String(); !strings.Contains(got, "Rayleigh scattering") {
t.Error("generated text content not found in response")
}
})

t.Run("TPU predict", func(t *testing.T) {
buf.Reset()
// Mock ID used to check if TPU was called
if err := predictTPU(&buf, client, projectID, TPUEndpointRegion, TPUEndpointID); err != nil {
t.Fatal(err)
}

if got := buf.String(); !strings.Contains(got, "Rayleigh scattering") {
t.Error("generated text content not found in response")
}
})
}

type PredictClientInterface interface {
Close() error
Predict(ctx context.Context, req *aiplatformpb.PredictRequest, opts ...gax.CallOption) (*aiplatformpb.PredictResponse, error)
}
45 changes: 45 additions & 0 deletions vertexai/gemma2/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
module github.com/GoogleCloudPlatform/golang-samples/gemma2

go 1.21

require (
cloud.google.com/go/aiplatform v1.68.0
github.com/GoogleCloudPlatform/golang-samples v0.0.0-20240918200157-a00ca430a14b
github.com/googleapis/gax-go/v2 v2.13.0
google.golang.org/protobuf v1.34.2
)

require (
cloud.google.com/go v0.115.1 // indirect
cloud.google.com/go/auth v0.9.3 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect
cloud.google.com/go/compute/metadata v0.5.0 // indirect
cloud.google.com/go/iam v1.2.0 // indirect
cloud.google.com/go/longrunning v0.6.0 // indirect
cloud.google.com/go/storage v1.43.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/s2a-go v0.1.8 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
go.opentelemetry.io/otel v1.29.0 // indirect
go.opentelemetry.io/otel/metric v1.29.0 // indirect
go.opentelemetry.io/otel/trace v1.29.0 // indirect
golang.org/x/crypto v0.27.0 // indirect
golang.org/x/net v0.29.0 // indirect
golang.org/x/oauth2 v0.23.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
golang.org/x/time v0.6.0 // indirect
google.golang.org/api v0.197.0 // indirect
google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/grpc v1.66.1 // indirect
)
Loading

0 comments on commit 8957a58

Please sign in to comment.