Disparate Performance between Python and Java #602

@ryanhausen

Please make sure that this is a bug. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04 x86_64): Ubuntu 24.04 x86_64
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 1.0.0
  • Java version (i.e., the output of java -version): openjdk version "21.0.6" 2025-01-21
  • Java command line flags (e.g., GC parameters):
  • Python version (if transferring a model trained in Python): 3.12.8
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: 12.8.61/8905
  • GPU model and memory: V100 (32GB)

Describe the current behavior

Executing the exported model with TensorFlow in Python takes significantly less time than calling the same function through TensorFlow Java. I suspect that I am just not using the Java API correctly, because a small change to the Python code can lead to comparably poor performance in Python.

Describe the expected behavior

The function calls should take a comparable amount of time.

Code to reproduce the issue

I have the following Python function:

import tensorflow as tf

@tf.function(
    input_signature=[
        tf.TensorSpec(shape=[41, 2048, 2048], dtype=tf.float32, name="data"),  # [k, n, n]
        tf.TensorSpec(shape=[1, 2048, 2048], dtype=tf.float32, name="image"),  # [1, n, n]
        tf.TensorSpec(shape=[41, 2048, 2048], dtype=tf.float32, name="psf"),  # [k, n, n]
    ],
    jit_compile=True,
)
def rl_step(
    data: tf.Tensor,  # [k, n, n]
    image: tf.Tensor, # [1, n, n]
    psf: tf.Tensor,   # [k, n, n]
) -> tf.Tensor: # [k, n, n]
    psf_fft = tf.signal.rfft2d(psf)
    psft_fft = tf.signal.rfft2d(tf.reverse(psf, axis=(-2, -1)))
    denom = tf.reduce_sum(
        tf.signal.irfft2d(psf_fft * tf.signal.rfft2d(data)),
        axis=0,
        keepdims=True
    )
    img_err = image / denom
    return data * tf.signal.irfft2d(tf.signal.rfft2d(img_err) * psft_fft)

In Python, this function is called repeatedly on the same tensors, as shown below:

    image_tensor = tf.constant(image) # [1, n, n]
    measured_psf_tensor = tf.constant(measured_psf) # [k, n, n]
    data_tensor = tf.constant(data) # [k, n, n]

    for i in range(10):
        start = time()
        data = rl_step(data_tensor, image_tensor, measured_psf_tensor)
        print(f"Iter {i}:", time() - start, "seconds.")

Here image, measured_psf, and data are all 3D arrays with dtype=float32, where n=2048 and k=41.

This prints timings around the following:

Iter 0: 0.2061774730682373 seconds.
Iter 1: 0.004193544387817383 seconds.
Iter 2: 0.0007469654083251953 seconds.
Iter 3: 0.000415802001953125 seconds.
Iter 4: 0.0004220008850097656 seconds.
Iter 5: 0.0004246234893798828 seconds.
Iter 6: 0.0004112720489501953 seconds.
Iter 7: 0.00042128562927246094 seconds.
Iter 8: 0.0004055500030517578 seconds.
Iter 9: 0.00040721893310546875 seconds.
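As a rough sanity check on these numbers, the sketch below compares the steady-state time per call with the time needed just to write the output tensor once. The 900 GB/s figure is an assumed nominal memory bandwidth for a V100, not something measured on this system. If the reported time falls below that floor, the timer may be capturing only the kernel launch rather than the finished computation.

```python
# Back-of-envelope check: can ~0.4 ms cover one pass over the output tensor?
K, N = 41, 2048                 # shapes taken from the input_signature
BYTES_PER_FLOAT32 = 4

# Size of one [k, n, n] float32 tensor (the function's return value).
output_bytes = K * N * N * BYTES_PER_FLOAT32

# ASSUMPTION: nominal V100 HBM2 bandwidth, not measured on this system.
ASSUMED_GPU_BANDWIDTH = 900e9   # bytes per second

# Lower bound: time to write the output once, ignoring all FFT work.
min_write_seconds = output_bytes / ASSUMED_GPU_BANDWIDTH

# Typical steady-state value from the Python timings in this report.
reported_seconds = 0.00041

print(f"output size: {output_bytes / 2**20:.0f} MiB")
print(f"minimum time to write output: {min_write_seconds * 1e3:.2f} ms")
print(f"reported time per call: {reported_seconds * 1e3:.2f} ms")
print("reported time is below the floor:", reported_seconds < min_write_seconds)
```

Under that assumption the Python timings are probably not directly comparable to a measurement that waits for the result to arrive on the host.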

I tried exporting the model by adding the following after the timing code:

    mod = tf.Module()
    mod.f = rl_step
    tf.saved_model.save(mod, "pure_tf_export")

I then tried to use this exported model from the Java API:

        String modelLocation = "./pure_tf_export";
        try(Graph g = new Graph(); Session s = new Session(g)){
            SavedModelBundle model = SavedModelBundle.loader(modelLocation).load();

            try (Tensor imageTensor = TFloat32.tensorOf(image);
                Tensor psfTensor = TFloat32.tensorOf(psf);
                Tensor dataTensor = TFloat32.tensorOf(data)
            ){
                Map<String, Tensor> inputs = new HashMap<String, Tensor>();
                inputs.put("data", dataTensor);
                inputs.put("image", imageTensor);
                inputs.put("psf", psfTensor);

                for (int i = 0; i < 10; i++){

                    Instant start = Instant.now();

                    Result result = model.function("serving_default").call(inputs);
                    inputs.replace("data", result.get("output_0").get());

                    System.out.println("Iter " + i + " " + (Duration.between(start, Instant.now()).toMillis()/1000f) + " seconds");
                }
            }
        }

And I get timings as follows:

Iter 0 0.701 seconds
Iter 1 0.528 seconds
Iter 2 0.874 seconds
Iter 3 0.224 seconds
Iter 4 0.254 seconds
Iter 5 1.622 seconds
Iter 6 0.241 seconds
Iter 7 0.224 seconds
Iter 8 0.231 seconds
Iter 9 0.228 seconds
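To put the gap in numbers, the sketch below compares the medians of the two timing listings in this report. The second half is only a guess at where the Java time might go: it assumes that each call copies all three inputs to the GPU and the output back to the host, and it assumes an effective PCIe bandwidth of 12 GB/s. Both are assumptions for illustration, not measurements.

```python
from statistics import median

# Timings copied from the two listings in this report (seconds).
python_times = [
    0.2061774730682373, 0.004193544387817383, 0.0007469654083251953,
    0.000415802001953125, 0.0004220008850097656, 0.0004246234893798828,
    0.0004112720489501953, 0.00042128562927246094, 0.0004055500030517578,
    0.00040721893310546875,
]
java_times = [0.701, 0.528, 0.874, 0.224, 0.254, 1.622, 0.241, 0.224, 0.231, 0.228]

python_median = median(python_times)
java_median = median(java_times)
ratio = java_median / python_median

print(f"Python median: {python_median * 1e3:.3f} ms")
print(f"Java median:   {java_median * 1e3:.1f} ms")
print(f"Java / Python: {ratio:.0f}x")

# ASSUMPTION: effective host<->device bandwidth, not measured here.
ASSUMED_PCIE_BANDWIDTH = 12e9   # bytes per second
K, N = 41, 2048


def tensor_bytes(leading_dim: int) -> int:
    """Size in bytes of a float32 tensor of shape [leading_dim, N, N]."""
    return leading_dim * N * N * 4


# ASSUMPTION: data, image, and psf go host -> device and the output comes back.
bytes_moved = tensor_bytes(K) + tensor_bytes(1) + tensor_bytes(K) + tensor_bytes(K)
estimated_transfer_seconds = bytes_moved / ASSUMED_PCIE_BANDWIDTH

print(f"bytes moved per call (assumed): {bytes_moved / 2**30:.2f} GiB")
print(f"estimated transfer time: {estimated_transfer_seconds:.3f} s")
```

If the estimate lands in the same range as the Java median, host-to-device copies on every call would be one plausible explanation worth checking.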

I am pretty sure I am making a simple mistake somewhere. I suspect it is in how I am instantiating the Tensors. I know that in Python, if you don't wrap the inputs in tf.constant, the timings go up a lot.
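One way to test that suspicion is to time both versions the same way, forcing each result to be fully available before the clock stops. The helper below is a minimal sketch using only the standard library; `time_steps`, `step`, and `sync` are hypothetical names introduced for illustration. It also feeds each output back in as the next input, which is what the Java loop does.

```python
from time import perf_counter
from typing import Callable, List, TypeVar

T = TypeVar("T")


def time_steps(
    step: Callable[[T], T],
    state: T,
    sync: Callable[[T], object],
    n_iters: int = 10,
) -> List[float]:
    """Run `step` n_iters times, feeding each output into the next call.

    `sync` is called on every result before the clock stops, so work that
    was only queued (not finished) is still counted in the duration.
    """
    durations = []
    for _ in range(n_iters):
        start = perf_counter()
        state = step(state)
        sync(state)  # block until the result is really there
        durations.append(perf_counter() - start)
    return durations


if __name__ == "__main__":
    # Stand-in workload so the sketch runs without TensorFlow installed.
    timings = time_steps(lambda xs: [x * 0.5 for x in xs], [1.0] * 1000, sync=len)
    for i, seconds in enumerate(timings):
        print(f"Iter {i}: {seconds:.6f} seconds.")
```

With the function from this report, the call might look like `time_steps(lambda d: rl_step(d, image_tensor, measured_psf_tensor), data_tensor, sync=lambda t: t.numpy())`. Calling `.numpy()` waits for the value to reach the host, so the measured time then also includes the device-to-host copy.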

Any help would be very much appreciated. I tried looking through the documentation and the tensorflow java-examples repository, but couldn't spot what I am doing wrong.

Thanks again!
